1997 LongShortTermMemory

From GM-RKB

Subject Headings: Long Short Term Memory.

Notes

Cited By

Quotes

Abstract

Learning to store information over extended time intervals by recurrent backpropagation takes a very long time, mostly because of insufficient, decaying error backflow. We briefly review Hochreiter's (1991) analysis of this problem, then address it by introducing a novel, efficient, gradient-based method called long short-term memory (LSTM). Truncating the gradient where this does not do harm, LSTM can learn to bridge minimal time lags in excess of 1000 discrete time steps by enforcing constant error flow through constant error carousels within special units. Multiplicative gate units learn to open and close access to the constant error flow. LSTM is local in space and time; its computational complexity per time step and weight is O(1). Our experiments with artificial data involve local, distributed, real-valued, and noisy pattern representations. In comparisons with real-time recurrent learning, back propagation through time, recurrent cascade correlation, Elman nets, and neural sequence chunking, LSTM leads to many more successful runs, and learns much faster. LSTM also solves complex, artificial long-time-lag tasks that have never been solved by previous recurrent network algorithms.

1. Introduction

Recurrent networks can in principle use their feedback connections to store representations of recent input events in the form of activations ("short-term memory", as opposed to "long-term memory" embodied by slowly changing weights). This is potentially significant for many applications, including speech processing, non-Markovian control, and music composition (e.g., Mozer 1992). The most widely used algorithms for learning what to put in short-term memory, however, take too much time or do not work well at all, especially when minimal time lags between inputs and corresponding teacher signals are long. Although theoretically fascinating, existing methods do not provide clear practical advantages over, say, backprop in feedforward nets with limited time windows. This paper will review an analysis of the problem and suggest a remedy.

The problem. With conventional "Back-Propagation Through Time" (BPTT, e.g., Williams and Zipser 1992, Werbos 1988) or "Real-Time Recurrent Learning" (RTRL, e.g., Robinson and Fallside 1987), error signals "flowing backwards in time" tend to either (1) blow up or (2) vanish: the temporal evolution of the backpropagated error depends exponentially on the size of the weights (Hochreiter 1991). Case (1) may lead to oscillating weights, while in case (2) learning to bridge long time lags takes a prohibitive amount of time, or does not work at all (see Section 3).

The remedy. This paper presents "Long Short-Term Memory" (LSTM), a novel recurrent network architecture in conjunction with an appropriate gradient-based learning algorithm. LSTM is designed to overcome these error backflow problems. It can learn to bridge time intervals in excess of 1000 steps even in the case of noisy, incompressible input sequences, without loss of short time lag capabilities. This is achieved by an efficient, gradient-based algorithm for an architecture enforcing constant (thus neither exploding nor vanishing) error flow through internal states of special units (provided the gradient computation is truncated at certain architecture-specific points; this does not affect long-term error flow, though).

Outline of paper. Section 2 will briefly review previous work. Section 3 begins with an outline of the detailed analysis of vanishing errors due to Hochreiter (1991). It will then introduce a naive approach to constant error backprop for didactic purposes, and highlight its problems concerning information storage and retrieval. These problems will lead to the LSTM architecture as described in Section 4. Section 5 will present numerous experiments and comparisons with competing methods. LSTM outperforms them, and also learns to solve complex, artificial tasks no other recurrent net algorithm has solved. Section 6 will discuss LSTM's limitations and advantages. The appendix contains a detailed description of the algorithm (A.1), and explicit error flow formulae (A.2).

2. Previous Work

This section will focus on recurrent nets with time-varying inputs (as opposed to nets with stationary inputs and fixpoint-based gradient calculations, e.g., Almeida 1987, Pineda 1987).

Gradient-descent variants. The approaches of Elman (1988), Fahlman (1991), Williams (1989), Schmidhuber (1992a), Pearlmutter (1989), and many of the related algorithms in Pearlmutter's comprehensive overview (1995) suffer from the same problems as BPTT and RTRL (see Sections 1 and 3).

Time-delays. Other methods that seem practical for short time lags only are Time-Delay Neural Networks (Lang et al. 1990) and Plate's method (Plate 1993), which updates unit activations based on a weighted sum of old activations (see also de Vries and Principe 1991). Lin et al. (1995) propose variants of time-delay networks called NARX networks.

Time constants. To deal with long time lags, Mozer (1992) uses time constants influencing changes of unit activations (de Vries and Principe's above-mentioned approach (1991) may in fact be viewed as a mixture of TDNN and time constants). For long time lags, however, the time constants need external fine tuning (Mozer 1992). Sun et al.'s alternative approach (1993) updates the activation of a recurrent unit by adding the old activation and the (scaled) current net input. The net input, however, tends to perturb the stored information, which makes long-term storage impractical.

Ring's approach. Ring (1993) also proposed a method for bridging long time lags. Whenever a unit in his network receives conflicting error signals, he adds a higher order unit influencing appropriate connections. Although his approach can sometimes be extremely fast, bridging a time lag involving 100 steps may require the addition of 100 units. Also, Ring's net does not generalize to unseen lag durations.

Bengio et al.'s approaches. Bengio et al. (1994) investigate methods such as simulated annealing, multi-grid random search, time-weighted pseudo-Newton optimization, and discrete error propagation. Their "latch" and "2-sequence" problems are very similar to problem 3a with minimal time lag 100 (see Experiment 3). Bengio and Frasconi (1994) also propose an EM approach for propagating targets. With n so-called "state networks", at a given time, their system can be in one of only n different states. See also the beginning of Section 5. But to solve continuous problems such as the "adding problem" (Section 5.4), their system would require an unacceptable number of states (i.e., state networks).

Kalman filters. Puskorius and Feldkamp (1994) use Kalman filter techniques to improve recurrent net performance. Since they use "a derivative discount factor imposed to decay exponentially the effects of past dynamic derivatives", there is no reason to believe that their Kalman Filter Trained Recurrent Networks will be useful for very long minimal time lags.

Second order nets. We will see that LSTM uses multiplicative units (MUs) to protect error flow from unwanted perturbations. It is not the first recurrent net method using MUs, though. For instance, Watrous and Kuhn (1992) use MUs in second order nets. Some differences from LSTM are: (1) Watrous and Kuhn's architecture does not enforce constant error flow and is not designed to solve long time lag problems. (2) It has fully connected second-order sigma-pi units, while the LSTM architecture's MUs are used only to gate access to constant error flow. (3) Watrous and Kuhn's algorithm costs $O(W^2)$ operations per time step, ours only $O(W)$, where $W$ is the number of weights. See also Miller and Giles (1993) for additional work on MUs.

Simple weight guessing. To avoid long time lag problems of gradient-based approaches, we may simply keep drawing random values for all network weights until the resulting net happens to classify all training sequences correctly. In fact, recently we discovered (Schmidhuber and Hochreiter 1996, Hochreiter and Schmidhuber 1996, 1997) that simple weight guessing solves many of the problems in (Bengio 1994, Bengio and Frasconi 1994, Miller and Giles 1993, Lin et al. 1995) faster than the algorithms proposed therein. This does not mean that weight guessing is a good algorithm. It just means that the problems are very simple. More realistic tasks require either many free parameters (e.g., input weights) or high weight precision (e.g., for continuous-valued parameters), such that guessing becomes completely infeasible.
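A minimal sketch of the weight-guessing procedure, assuming a hypothetical single self-connected tanh unit with an output weight and a hypothetical two-sequence "latch"-style toy task (first input $\pm 1$, then zeros, class given by the first input). The weight range $[-10, 10]$ and the helpers `run_net` and `guess_weights` are illustrative choices, not the networks or tasks used in the cited experiments.

```python
import math
import random


def run_net(w_in, w_rec, w_out, sequence):
    """Hypothetical toy net: one tanh unit with a self-connection, read out once at the end."""
    h = 0.0
    for x in sequence:
        h = math.tanh(w_in * x + w_rec * h)
    return w_out * h


def guess_weights(sequences, targets, max_trials=10_000, seed=0):
    """Draw all weights at random until every sequence gets the sign of its target.

    Returns the accepted weights [w_in, w_rec, w_out] and the number of trials,
    or (None, max_trials) if no guess succeeded.
    """
    rng = random.Random(seed)
    for trial in range(1, max_trials + 1):
        weights = [rng.uniform(-10.0, 10.0) for _ in range(3)]
        if all(run_net(*weights, seq) * target > 0
               for seq, target in zip(sequences, targets)):
            return weights, trial
    return None, max_trials


# Hypothetical toy task: only the first input determines the class.
T = 100
sequences = [[1.0] + [0.0] * (T - 1), [-1.0] + [0.0] * (T - 1)]
targets = [1.0, -1.0]
weights, trials = guess_weights(sequences, targets)
print(weights, trials)
```

On this toy task about half of all random draws classify both sequences correctly (the output sign depends only on the signs of the three weights), which illustrates the point above: the problem is easy, not the search method clever.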

Adaptive sequence chunkers. Schmidhuber's hierarchical chunker systems (1992b, 1993) do have a capability to bridge arbitrary time lags, but only if there is local predictability across the subsequences causing the time lags (see also Mozer 1992). For instance, in his postdoctoral thesis (1993), Schmidhuber uses hierarchical recurrent nets to rapidly solve certain grammar learning tasks involving minimal time lags in excess of 1000 steps. The performance of chunker systems, however, deteriorates as the noise level increases and the input sequences become less compressible. LSTM does not suffer from this problem.

3. Constant Error Backprop

3.1 Exponentially Decaying Error

Conventional BPTT (e.g. Williams and Zipser 1992). Output unit $k$'s target at time $t$ is denoted by $d_k (t)$. Using mean squared error, $k$'s error signal is

$\vartheta_{k}(t)=f_{k}^{\prime}\left(\operatorname{net}_{k}(t)\right)\left(d_{k}(t)-y^{k}(t)\right)$

where

$y^{i}(t)=f_{i}\left(\text { net }_{i}(t)\right)$

is the activation of a non-input unit $i$ with differentiable activation function $f_i$,

$\operatorname{net}_{i}(t)=\displaystyle \sum_j w_{i j} y^{j}(t-1)$

is unit $i$'s current net input, and $w_{ij}$ is the weight on the connection from unit $j$ to $i$. Some non-output unit $j$'s backpropagated error signal is

$\vartheta_{j}(t)=f_{j}^{\prime}\left(\operatorname{net}_{j}(t)\right)\displaystyle \sum_{i} w_{i j} \vartheta_{i}(t+1)$

The corresponding contribution to $w_{jl}$ 's total weight update is $\alpha\vartheta_{j}(t)y^{l} (t-1)$, where $\alpha$ is the learning rate, and $l$ stands for an arbitrary unit connected to unit $j$.
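The following NumPy sketch applies these formulas to a small fully connected net unrolled for two time steps. It assumes logistic sigmoid activations for all units, a target only at the final step, and a hypothetical helper `bptt_two_steps`. Because $\vartheta_k = -\partial E / \partial \operatorname{net}_k$ for the squared error $E = \frac{1}{2}\sum_k (d_k - y^k)^2$, the summed updates equal $-\alpha\, \partial E / \partial w_{jl}$, which can be checked against finite differences.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sigmoid_prime(x):
    s = sigmoid(x)
    return s * (1.0 - s)


def bptt_two_steps(W, y0, d, out_idx, alpha=0.1):
    """Hypothetical helper: unroll a fully connected net from t = 0 to t = 2.

    W[i, j] is w_ij (weight from unit j to unit i), y0 holds y^j(0), and d holds
    the targets d_k(2) of the output units listed in out_idx. Returns the summed
    weight update and the final activations.
    """
    # Forward pass: net_i(t) = sum_j w_ij y^j(t-1) and y^i(t) = f_i(net_i(t)).
    net1 = W @ y0
    y1 = sigmoid(net1)
    net2 = W @ y1
    y2 = sigmoid(net2)

    # Output error at t = 2: theta_k(2) = f'_k(net_k(2)) * (d_k(2) - y^k(2)).
    theta2 = np.zeros_like(net2)
    theta2[out_idx] = sigmoid_prime(net2[out_idx]) * (d - y2[out_idx])

    # Backpropagated error at t = 1: theta_j(1) = f'_j(net_j(1)) * sum_i w_ij theta_i(2).
    # No unit has a target at t = 1, so every unit receives only backpropagated error.
    theta1 = sigmoid_prime(net1) * (W.T @ theta2)

    # Contribution to w_jl at each step: alpha * theta_j(t) * y^l(t-1), summed over t.
    delta_W = alpha * (np.outer(theta2, y1) + np.outer(theta1, y0))
    return delta_W, y2


rng = np.random.default_rng(0)
W = rng.normal(scale=0.5, size=(4, 4))
y0 = rng.uniform(size=4)
delta_W, y2 = bptt_two_steps(W, y0, d=np.array([1.0]), out_idx=np.array([3]))
print(delta_W)
```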

Outline of Hochreiter's analysis (1991, pages 19-21). Suppose we have a fully connected net whose non-input unit indices range from $1$ to $n$. Let us focus on local error flow from unit $u$ to unit $v$ (later we will see that the analysis immediately extends to global error flow). The error occurring at an arbitrary unit $u$ at time step $t$ is propagated "back into time" for $q$ time steps, to an arbitrary unit $v$. This will scale the error by the following factor:

$\dfrac{\partial \vartheta_{v}(t-q)}{\partial \vartheta_{u}(t)}=\begin{cases} f_{v}^{\prime}\left(\operatorname{net}_{v}(t-1)\right) w_{u v} & q=1 \\ f_{v}^{\prime}\left(\operatorname{net}_{v}(t-q)\right) \displaystyle\sum_{l=1}^{n} \dfrac{\partial \vartheta_{l}(t-q+1)}{\partial \vartheta_{u}(t)} w_{l v} & q>1 \end{cases}$

(1)

With $l_q = v$ and $l_0 = u$, we obtain:

$\dfrac{\partial \vartheta_{v}(t-q)}{\partial \vartheta_{u}(t)}=\displaystyle\sum_{l_{1}=1}^{n} \ldots \sum_{l_{q-1}=1}^{n} \prod_{m=1}^{q} f_{l_{m}}^{\prime}\left(\operatorname{net}_{l_{m}}(t-m)\right) w_{l_{m} l_{m-1}}$

(2)

(proof by induction). The sum of the $n^{q-1}$ terms $\prod_{m=1}^{q} f_{l_{m}}^{\prime}\left(\operatorname{net}_{l_{m}}(t-m)\right) w_{l_{m} l_{m-1}}$ determines the total error backflow (note that since the summation terms may have different signs, increasing the number of units $n$ does not necessarily increase error flow).
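To check that the path sum (2) reproduces recursion (1), the sketch below evaluates both for a small random net. The weights, the derivative values $f'_l(\operatorname{net}_l(t-m))$ (drawn from $[0, 0.25]$ as for a logistic sigmoid) and the function names are hypothetical. With $w_{ij}$ denoting the weight from unit $j$ to unit $i$, as defined above, recursion (1) multiplies each backward step by the weight on the connection from $l_m$ to $l_{m-1}$ (for $q = 1$ this is $w_{uv}$), so the sketch indexes each path factor as $w_{l_{m-1} l_m}$; read with that convention, the path sum agrees with the recursion.

```python
import itertools

import numpy as np


def scaling_recursive(W, fprime, u, v, q):
    """d theta_v(t-q) / d theta_u(t) computed with recursion (1).

    W[i, j] is w_ij (weight from unit j to unit i); fprime[m, l] is f'_l(net_l(t-m)).
    """
    # q = 1: f'_l(net_l(t-1)) * w_ul for every unit l
    grad = fprime[1] * W[u, :]
    for s in range(2, q + 1):
        # q = s: f'_l(net_l(t-s)) * sum_k [d theta_k(t-s+1) / d theta_u(t)] * w_kl
        grad = fprime[s] * (grad @ W)
    return grad[v]


def scaling_path_sum(W, fprime, u, v, q):
    """The same factor as an explicit sum over all n**(q-1) paths, as in (2)."""
    n = W.shape[0]
    total = 0.0
    for middle in itertools.product(range(n), repeat=q - 1):
        path = (u, *middle, v)  # l_0 = u, l_1, ..., l_q = v
        prod = 1.0
        for m in range(1, q + 1):
            # f'_{l_m}(net_{l_m}(t-m)) times the weight from l_m to l_{m-1}
            prod *= fprime[m, path[m]] * W[path[m - 1], path[m]]
        total += prod
    return total


rng = np.random.default_rng(1)
n, q = 3, 4
W = rng.normal(size=(n, n))
fprime = rng.uniform(0.0, 0.25, size=(q + 1, n))  # row 0 is unused
print(scaling_recursive(W, fprime, 0, 2, q), scaling_path_sum(W, fprime, 0, 2, q))
```

The recursion costs $O(q n^2)$ operations, while the explicit path sum enumerates all $n^{q-1}$ paths, so the latter is only practical as a check on tiny nets.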

Intuitive explanation of equation (2). If

$\left|f_{l_{m}}^{\prime}\left(\operatorname{net}_{l_{m}}(t-m)\right) w_{l_{m} l_{m-1}}\right|>1.0$

for all $m$ (as can happen, e.g., with linear $f_{l_m}$), then the largest product increases exponentially with $q$. That is, the error blows up, and conflicting error signals arriving at unit $v$ can lead to oscillating weights and unstable learning (for error blow-ups or bifurcations see also Pineda 1988, Baldi and Pineda 1991, Doya 1992). On the other hand, if

$\left|f_{l_{m}}^{\prime}\left(\operatorname{net}_{l_{m}}(t-m)\right) w_{l_{m} l_{m-1}}\right|<1.0$

for all $m$, then the largest product decreases exponentially with $q$. That is, the error vanishes, and nothing can be learned in acceptable time.
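A minimal numerical illustration, assuming the idealized case of a single unit with a self-connection of weight $w$ and a constant derivative $f'$, so that every factor in (2) equals $f' w$; the helper `backprop_self_loop` is hypothetical. With a linear unit ($f' = 1$), $w = 1.2$ blows the error up and $w = 0.8$ makes it vanish. The borderline case $f' w = 1.0$, where the error neither grows nor shrinks, corresponds to the constant error flow that LSTM enforces.

```python
def backprop_self_loop(w, fprime, q, theta_t=1.0):
    """Propagate an error signal q steps back through one self-connected unit.

    Each step applies theta(s - 1) = fprime * w * theta(s), i.e. one factor
    f'(net) * w of equation (2) with a constant derivative (idealization).
    """
    theta = theta_t
    for _ in range(q):
        theta = fprime * w * theta
    return theta


for w in (1.2, 1.0, 0.8):
    # linear unit: f'(net) = 1 everywhere
    print(w, [backprop_self_loop(w, 1.0, q) for q in (10, 50, 100)])
```

After 100 steps the first and last cases differ from the original error by factors of roughly $10^{8}$ and $10^{-10}$, respectively.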

If $f_{l_m}$ is the logistic sigmoid function, then the maximal value of $f_{l_m}^{\prime}$ is 0.25. If $y^{l_{m-1}}$ is constant and not equal to zero, then $\left|f_{l_{m}}^{\prime}\left(\operatorname{net}_{l_{m}}\right) w_{l_{m} l_{m-1}}\right|$ takes on maximal values where

$w_{l_{m} l_{m-1}}=\dfrac{1}{y^{l_{m-1}}} \coth\left(\frac{1}{2} \operatorname{net}_{l_{m}}\right)$,

goes to zero for $\left|w_{l_{m} l_{m-1}}\right| \to \infty$, and is less than 1.0 for $\left|w_{l_{m} l_{m-1}}\right| < 4.0$ (e.g., if the absolute maximal weight value $w_{\max}$ is smaller than 4.0). Hence with conventional logistic sigmoid activation functions, the error flow tends to vanish as long as the weights have absolute values below 4.0, especially at the beginning of the training phase. In general, the use of larger initial weights will not help, though: as seen above, for $\left|w_{l_{m} l_{m-1}}\right|\to \infty$ the relevant derivative goes to zero "faster" than the absolute weight can grow (also, some weights will have to change their signs by crossing zero). Likewise, increasing the learning rate does not help either: it will not change the ratio of long-range error flow and short-range error flow. BPTT is too sensitive to recent distractions. (A very similar, more recent analysis was presented by Bengio et al. 1994.)
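The claims about the logistic sigmoid can be checked numerically. The sketch assumes a single factor $|f'(\operatorname{net})\, w|$ in which the weight also enters the unit's net input, $\operatorname{net} = w\, y + c$, with a constant nonzero incoming activation $y$ and a constant offset $c$ standing in for all other inputs; `scaling_factor` and `maximizing_weight` are hypothetical helpers. Setting the derivative with respect to $w$ to zero and using $f'' = -f' \tanh(\operatorname{net}/2)$ gives $w\, y \tanh(\operatorname{net}/2) = 1$, which is the coth condition above.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def scaling_factor(w, y, c=0.0):
    """|f'(net) * w| for a logistic unit with net = w * y + c.

    Sketch assumptions: the incoming activation y is constant and nonzero, and
    c stands in for all other (fixed) contributions to the net input.
    """
    s = sigmoid(w * y + c)
    return np.abs(s * (1.0 - s) * w)


def maximizing_weight(y, c=0.0, w_max=50.0, num=500_001):
    """Grid search for the positive weight that maximizes scaling_factor."""
    w = np.linspace(1e-6, w_max, num)
    return w[np.argmax(scaling_factor(w, y, c))]


for y in (1.0, 0.2):
    w_star = maximizing_weight(y)
    net = w_star * y
    # the maximizer satisfies w = coth(net / 2) / y
    print(y, w_star, scaling_factor(w_star, y), 1.0 / (y * np.tanh(net / 2.0)))
```

With a small incoming activation ($y = 0.2$), a weight near 7.7 pushes the factor slightly above 1.0; because $f' \le 0.25$, this cannot happen for $|w| < 4.0$.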

Global error flow. The local error flow analysis above immediately shows that global error flow vanishes, too. To see this, compute

$\displaystyle \sum_{u:\ u \text{ output unit}}\dfrac{\partial\vartheta_v\left(t-q\right)}{\partial\vartheta_u\left(t\right)}$.
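A sketch of this computation, assuming a hypothetical 4-unit net with weights drawn from $[-0.5, 0.5]$ and derivative values in $[0, 0.25]$ (as for logistic sigmoids); `global_error_flow` is a hypothetical helper. It builds all local factors with recursion (1) in matrix form and sums the rows belonging to the output units. Under these assumptions each backward step multiplies the largest absolute local factor by at most $n \cdot 0.25 \cdot 0.5 = 0.5$, so the global flow from $k$ output units is bounded by $k \cdot 0.5^{q}$.

```python
import numpy as np


def global_error_flow(W, fprime, outputs, q):
    """For every unit v, sum d theta_v(t-q) / d theta_u(t) over the output units u.

    W[i, j] is w_ij (weight from unit j to unit i); fprime[s, l] is f'_l(net_l(t-s)).
    Row u of G holds the local factors from unit u, built step by step with recursion (1).
    """
    n = W.shape[0]
    G = np.eye(n)  # no steps back yet: each error scales only itself
    for s in range(1, q + 1):
        # G[u, l] <- f'_l(net_l(t-s)) * sum_k G[u, k] * w_kl
        G = (G @ W) * fprime[s][np.newaxis, :]
    return G[outputs].sum(axis=0)


rng = np.random.default_rng(2)
n, q_max = 4, 40
W = rng.uniform(-0.5, 0.5, size=(n, n))
fprime = rng.uniform(0.0, 0.25, size=(q_max + 1, n))  # row 0 is unused
outputs = [0, 1]
for q in (1, 10, 40):
    print(q, np.abs(global_error_flow(W, fprime, outputs, q)).max())
```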

Weak Upper Bound for Scaling Factor

3.2 Constant Error Flow: Naive

4. Long Short-Term Memory

5. Experiments

6. Discussion

7. Conclusion

8. Acknowledgments

Appendix

References

BibTeX

@article{1997_LongShortTermMemory,
  author    = {Sepp Hochreiter and
               J{\"{u}}rgen Schmidhuber},
  title     = {Long Short-Term Memory},
  journal   = {Neural Computation},
  volume    = {9},
  number    = {8},
  pages     = {1735--1780},
  year      = {1997},
  url       = {http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.676.4320&rep=rep1&type=pdf},
  doi       = {10.1162/neco.1997.9.8.1735},
}


Page: 1997 LongShortTermMemory
Author: Jürgen Schmidhuber, Sepp Hochreiter
Title: Long Short-term Memory
Year: 1997