Learning Efficient Algorithms with Hierarchical Attentive Memory
Marcin Andrychowicz* (MARCINA@GOOGLE.COM), Google DeepMind
Karol Kurach* (KKURACH@GOOGLE.COM), Google / University of Warsaw¹
*Equal contribution. ¹Work done while at Google.
arXiv:1602.03218v2 [cs.LG] 23 Feb 2016

Abstract

In this paper, we propose and investigate a novel memory architecture for neural networks called Hierarchical Attentive Memory (HAM). It is based on a binary tree with leaves corresponding to memory cells. This allows HAM to perform memory access in Θ(log n) complexity, which is a significant improvement over the standard attention mechanism that requires Θ(n) operations, where n is the size of the memory.

We show that an LSTM network augmented with HAM can learn algorithms for problems like merging, sorting or binary searching from pure input-output examples. In particular, it learns to sort n numbers in time Θ(n log n) and generalizes well to input sequences much longer than the ones seen during training. We also show that HAM can be trained to act like classic data structures: a stack, a FIFO queue and a priority queue.

1. Introduction

Deep Recurrent Neural Networks (RNNs) have recently proven to be very successful on real-world tasks, e.g. machine translation (Sutskever et al., 2014) and computer vision (Vinyals et al., 2014). However, this success has been achieved only on tasks which do not require a large memory to solve the problem: we can translate sentences using RNNs, but we cannot produce reasonable translations of really long pieces of text, like books.

A high-capacity memory is a crucial component for dealing with large-scale problems that contain plenty of long-range dependencies. Currently used RNNs do not scale well to larger memories, e.g. the number of parameters in an LSTM (Hochreiter & Schmidhuber, 1997) grows quadratically with the size of the network's memory. In practice, this limits the number of memory cells used to a few thousand.

It would be desirable for the size of the memory to be independent of the number of model parameters. The first versatile and highly successful architecture with this property was the Neural Turing Machine (NTM) (Graves et al., 2014). The main idea behind the NTM is to split the network into a trainable "controller" and an "external" variable-size memory. It triggered a wave of other neural network architectures with external memories (see Sec. 2).

However, one aspect which has usually been neglected so far is the efficiency of memory access. Most of the proposed memory architectures have Θ(n) access complexity, where n is the size of the memory. This means that, for instance, copying a sequence of length n requires performing Θ(n²) operations, which is clearly unsatisfactory.

1.1. Our contribution

In this paper we propose a novel memory module for neural networks, called Hierarchical Attentive Memory (HAM). The HAM module is generic and can be used as a building block of larger neural architectures. Its crucial property is that it scales well with the memory size: memory access requires only Θ(log n) operations, where n is the size of the memory. This complexity is achieved by using a new attention mechanism based on a binary tree with leaves corresponding to memory cells. The novel attention mechanism is not only faster than the standard one used in Deep Learning (Bahdanau et al., 2014), but it also facilitates learning algorithms due to a built-in bias towards operating on intervals.

We show that an LSTM augmented with HAM is able to learn algorithms for tasks like merging, sorting or binary searching. In particular, it is the first neural network we are aware of that is able to learn to sort from pure input-output examples and generalizes well to input sequences much longer than the ones seen during training. Moreover, the learned sorting algorithm runs in time Θ(n log n). We also show that the HAM memory itself is capable of simulating different classic memory structures: a stack, a FIFO queue and a priority queue.
2. Related work

In this section we mention a number of recently proposed neural architectures with an external memory whose size is independent of the number of model parameters.

Memory architectures based on attention. Attention is a recent but already extremely successful technique in Deep Learning. This mechanism allows networks to attend to parts of the (potentially preprocessed) input sequence (Bahdanau et al., 2014) while generating the output sequence. It is implemented by giving the network, as an auxiliary input, a linear combination of input symbols, where the weights of this linear combination can be controlled by the network.

The attention mechanism was used to access the memory in Neural Turing Machines (NTMs) (Graves et al., 2014). It was the first paper that explicitly attempted to train a computationally universal neural network, and it achieved encouraging results.

The Memory Network (Weston et al., 2014) is an early model that attempted to explicitly separate the memory from computation in a neural network model. The follow-up work of (Sukhbaatar et al., 2015) combined the memory network with the soft attention mechanism, which allowed it to be trained with less supervision. In contrast to NTMs, the memory in these models is non-writeable.

Another model without writeable memory is the Pointer Network (Vinyals et al., 2015), which is very similar to the attention model of Bahdanau et al. (2014). Despite not having a memory, this model was able to solve a number of difficult algorithmic problems, including the Convex Hull and the approximate 2D Travelling Salesman Problem.

All of the architectures mentioned so far use standard attention mechanisms to access the memory, and therefore their memory access complexity scales linearly with the memory size.

Memory architectures based on data structures. The Stack-Augmented Recurrent Neural Network (Joulin & Mikolov, 2015) is a neural architecture combining an RNN and a differentiable stack. In another paper (Grefenstette et al., 2015) the authors consider extending an LSTM with a stack, a FIFO queue or a double-ended queue and show some promising results. The advantage of the latter model is that the presented data structures have a constant access time.

Memory architectures based on pointers. In two recent papers (Zaremba & Sutskever, 2015; Zaremba et al., 2015) the authors consider extending neural networks with non-differentiable memories based on pointers, trained using Reinforcement Learning. The big advantage of these models is that they allow constant-time memory access. They were, however, only successful on relatively simple tasks.

Another model which can use a pointer-based memory is the Neural Programmer-Interpreter (Reed & de Freitas, 2015). It is very interesting because it managed to learn sub-procedures. Unfortunately, it requires strong supervision in the form of execution traces.

Another type of pointer-based memory was presented in the Neural Random-Access Machine (Kurach et al., 2015), which is a neural architecture mimicking classic computers.

Parallel memory architectures. There are two recent memory architectures which are especially suited for parallel computation. Grid-LSTM (Kalchbrenner et al., 2015) is an extension of the LSTM to multiple dimensions. Another recent model of this type is the Neural GPU (Kaiser & Sutskever, 2015), which can learn to multiply long binary numbers.

3. Hierarchical Attentive Memory

In this section we describe our novel memory module called Hierarchical Attentive Memory (HAM). The HAM module is generic and can be used as a building block of larger neural network architectures. For instance, it can be added to feedforward or LSTM networks to extend their capabilities. To make our description more concrete, we will consider a model consisting of an LSTM "controller" extended with a HAM module.

The high-level idea behind the HAM module is as follows. The memory is structured as a full binary tree with the leaves containing the data stored in the memory. The inner nodes contain auxiliary data, which allows us to efficiently perform some types of "queries" on the memory. In order to access the memory, one starts from the root of the tree and performs a top-down descent, which is similar to the hierarchical softmax procedure (Morin & Bengio, 2005). At every node of the tree, one decides to go left or right based on the auxiliary data stored in this node and a "query". Details are provided in the rest of this section.

3.1. Notation

The model takes as input a sequence x1, x2, ... and outputs a sequence y1, y2, .... We assume that each element of these sequences is a binary vector of size b ∈ N, i.e. xi, yi ∈ {0, 1}^b. Suppose for a moment that we only want to process input sequences of length ≤ n, where n ∈ N is a power of two (we show later how to process sequences of arbitrary length). The model is based on the full binary tree with n leaves. Let V denote the set of nodes in that tree (notice that |V| = 2n − 1) and let L ⊂ V denote the set of its leaves. For e ∈ V \ L, let l(e) be the left child of node e and r(e) its right child.
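The paper does not prescribe a concrete data layout for the tree; the minimal Python sketch below (an assumption on my part) shows the standard heap-style array indexing that realizes this notation: node 1 is the root, node e has children l(e) = 2e and r(e) = 2e + 1, and the leaves occupy indices n, ..., 2n − 1, so every root-to-leaf path has length log2(n).

```python
# Minimal sketch (not from the paper) of the full binary tree layout from Sec. 3.1,
# using 1-based heap indexing: node 1 is the root, node e has children 2e and 2e+1,
# and the n leaves occupy indices n .. 2n-1.

def left(e):          # l(e)
    return 2 * e

def right(e):         # r(e)
    return 2 * e + 1

def is_leaf(e, n):    # leaves occupy indices n .. 2n-1
    return e >= n

def num_nodes(n):     # |V| = 2n - 1 for a full binary tree with n leaves
    return 2 * n - 1

if __name__ == "__main__":
    n = 8                                  # number of memory cells (a power of two)
    assert num_nodes(n) == 15
    # A root-to-leaf path has length log2(n), which is what gives HAM its
    # Theta(log n) memory access.
    e, depth = 1, 0
    while not is_leaf(e, n):
        e = left(e)
        depth += 1
    assert depth == 3                      # log2(8)
```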
We will now present the inference procedure for the model and then discuss how to train it.

3.2. Inference

The high-level view of the model execution is presented in Fig. 1. The hidden state of the model consists of two components: the hidden state of the LSTM controller (denoted hLSTM ∈ R^l for some l ∈ N) and the hidden values stored in the nodes of the HAM tree. More precisely, for every node e ∈ V there is a hidden value he ∈ R^d. These values change during the recurrent execution of the model, but we drop all timestep indices to simplify the notation.

Figure 1. The LSTM+HAM model consists of an LSTM controller and a HAM module. The execution of the model starts with the initialization of HAM using the whole input sequence x1, x2, ..., xm. At each timestep, the HAM module produces an input for the LSTM, which then produces an output symbol yt. Afterwards, the hidden states of the LSTM and HAM are updated.

The parameters of the model describe the input-output behaviour of the LSTM, as well as the following 4 transformations, which describe the HAM module: EMBED : R^b → R^d, JOIN : R^d × R^d → R^d, SEARCH : R^d × R^l → [0, 1] and WRITE : R^d × R^l → R^d. These transformations may be represented by arbitrary function approximators, e.g. Multilayer Perceptrons (MLPs). Their meaning will be described soon.

The details of the model are presented in 4 figures. Fig. 2 describes the initialization of the model. Each recurrent timestep of the model consists of three phases: the attention phase described in Fig. 3, the output phase described in Fig. 4 and the update phase described in Fig. 5. The whole timestep can be performed in time Θ(log n).

Figure 2. Initialization of the model. The value in the i-th leaf of HAM is initialized with EMBED(xi), where EMBED is a trainable feed-forward network. If there are more leaves than input symbols, we initialize the values in the excessive leaves with zeros. Then, we initialize the values in the inner nodes bottom-up using the formula he = JOIN(hl(e), hr(e)). The hidden state of the LSTM, hLSTM, is initialized with zeros.

Figure 3. Attention phase. In this phase the model performs a top-down "search" in the tree starting from the root. Suppose that we are currently at the node c ∈ V \ L. We compute the value p = SEARCH(hc, hLSTM). Then, with probability p the model goes right (i.e. c := r(c)) and with probability 1 − p it goes left (i.e. c := l(c)). This procedure is continued until we reach one of the leaves. This leaf is called the attended or accessed leaf and is denoted a.

The HAM parameters describe only the 4 mentioned transformations, and hence the number of model parameters does not depend on the size of the binary tree used. Thus, we can use the model to process inputs of arbitrary length by using big enough binary trees. It is not obvious that the same set of parameters will give good results across different tree sizes, but we show experimentally that this is indeed the case (see Sec. 4 for more details).

We decided to represent the transformations defining HAM by MLPs with the ReLU (Nair & Hinton, 2010) activation function in all neurons except the output layer of SEARCH, which uses the sigmoid activation function to ensure that the output may be interpreted as a probability. Moreover, the network for WRITE is enhanced in a similar way as Highway Networks (Srivastava et al., 2015), i.e. WRITE(ha, hLSTM) = T(ha, hLSTM) · H(ha, hLSTM) + (1 − T(ha, hLSTM)) · ha, where H and T are two MLPs with a sigmoid activation function in the output layer. This allows the WRITE transformation to easily leave the value ha unchanged.
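The following NumPy sketch is my own illustration, not the authors' code: one HAM timestep with randomly initialized, untrained single-hidden-layer MLPs standing in for EMBED, JOIN, SEARCH and WRITE. The Highway-style gating of WRITE and the output phase are omitted for brevity, and all dimensions are arbitrary toy values.

```python
# A minimal NumPy sketch of one HAM timestep: initialization with EMBED/JOIN, the
# stochastic attention descent with SEARCH, and the WRITE/JOIN update along the
# visited path. All networks use random, untrained weights.
import numpy as np

rng = np.random.default_rng(0)
b, d, l, n = 8, 16, 20, 8          # input bits, node dim, controller dim, #leaves

def mlp(in_dim, out_dim, out_sigmoid=False):
    W1, W2 = rng.normal(0, 0.1, (in_dim, 32)), rng.normal(0, 0.1, (32, out_dim))
    def f(*xs):
        h = np.maximum(np.concatenate(xs) @ W1, 0.0)        # ReLU hidden layer
        y = h @ W2
        return 1.0 / (1.0 + np.exp(-y)) if out_sigmoid else y
    return f

EMBED  = mlp(b, d)
JOIN   = mlp(2 * d, d)
SEARCH = mlp(d + l, 1, out_sigmoid=True)   # outputs the probability of going right
WRITE  = mlp(d + l, d)                     # (Highway-style gating omitted here)

# --- initialization (Fig. 2): leaves get EMBED(x_i), inner nodes get JOIN bottom-up
x = rng.integers(0, 2, (n, b)).astype(float)      # toy input sequence
h = np.zeros((2 * n, d))                          # h[1 .. 2n-1]; index 0 unused
for i in range(n):
    h[n + i] = EMBED(x[i])
for e in range(n - 1, 0, -1):
    h[e] = JOIN(h[2 * e], h[2 * e + 1])

# --- attention phase (Fig. 3): stochastic top-down descent from the root
h_lstm = np.zeros(l)                              # controller state (zeros here)
e = 1
while e < n:                                      # until we reach a leaf
    p_right = SEARCH(h[e], h_lstm)[0]
    e = 2 * e + 1 if rng.random() < p_right else 2 * e
a = e                                             # the attended leaf

# --- update phase (Fig. 5): rewrite the leaf, then refresh JOINs on the path to root
h[a] = WRITE(h[a], h_lstm)
e = a // 2
while e >= 1:
    h[e] = JOIN(h[2 * e], h[2 * e + 1])
    e //= 2
print("attended leaf index:", a - n)
```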
Figure 4. Output phase. The value ha stored in the attended leaf is given to the LSTM as an input. Then, the LSTM produces an output symbol yt ∈ {0, 1}^b. More precisely, the value u ∈ R^b is computed by a trainable linear transformation from hLSTM, and the distribution of yt is defined by the formula p(yt,i = 1) = sigmoid(ui) for 1 ≤ i ≤ b. It may be beneficial to allow the model to access the memory a few times between producing consecutive output symbols. Therefore, the model produces an output symbol only at timesteps with indices divisible by some constant η ∈ N, which is a hyperparameter.

Figure 5. Update phase. In this phase the value in the attended leaf a is updated. More precisely, the value is modified using the formula ha := WRITE(ha, hLSTM). Then, we update the values of the inner nodes encountered during the attention phase (h6, h3 and h1 in the figure) bottom-up using the equation he = JOIN(hl(e), hr(e)).

3.3. Training

In this section we describe how to train our model from purely input-output examples using REINFORCE (Williams, 1992). In Appendix A we also present a different variant of HAM which is fully differentiable and can be trained using end-to-end backpropagation.

Let x, y be an input-output pair. Recall that both x and y are sequences. Moreover, let θ denote the parameters of the model and let A denote the sequence of all decisions whether to go left or right made during the whole execution of the model. We would like to maximize the log-probability of producing the correct output, i.e.

    L = log p(y|x, θ) = log Σ_A p(A|x, θ) p(y|A, x, θ).

This sum is intractable, so instead of maximizing it directly, we maximize a variational lower bound on it:

    F = Σ_A p(A|x, θ) log p(y|A, x, θ) ≤ L.

This sum is also intractable, so we approximate its gradient using REINFORCE, which we briefly explain below. Using the identity ∇p(A|x, θ) = p(A|x, θ) ∇ log p(A|x, θ), the gradient of the lower bound with respect to the model parameters can be rewritten as:

    ∇F = Σ_A p(A|x, θ) [∇ log p(y|A, x, θ) + log p(y|A, x, θ) ∇ log p(A|x, θ)]    (1)

We estimate this value using a Monte Carlo approximation. For every x we sample A from p(A|x, θ) and approximate the gradient for the input x as ∇ log p(y|A, x, θ) + log p(y|A, x, θ) ∇ log p(A|x, θ).

Notice that this gradient estimate can be computed using normal backpropagation if we substitute the gradients in the nodes² that sample whether we should go left or right during the attention phase by

    log p(y|A, x, θ) ∇ log p(A|x, θ).

This term is called the REINFORCE gradient estimate, and the left factor is called a return in the Reinforcement Learning literature. This gradient estimator is unbiased, but it often has a high variance. Therefore, we employ two standard variance-reduction techniques for REINFORCE: discounted returns and baselines (Williams, 1992). Discounted returns means that our return at the t-th timestep has the form Σ_{i≥t} γ^{i−t} log p(yi|A, x, θ) for some discount constant γ ∈ [0, 1], which is a hyperparameter. This biases the estimator if γ < 1, but it often decreases its variance. For lack of space we do not describe the baselines technique in detail; we only mention that our baseline is case- and timestep-dependent: it is computed using a learnable linear transformation from hLSTM and trained using the MSE loss function.

² For a general discussion of computing gradients in computation graphs which contain stochastic nodes, see (Schulman et al., 2015).
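As a concrete illustration of the estimator above, the sketch below (my own, framework-agnostic) assembles the discounted returns and the REINFORCE surrogate objective from given per-timestep log-likelihoods, decision log-probabilities and baselines. In an autodiff framework, the (return minus baseline) factor would additionally need a stop-gradient.

```python
# A small sketch of the REINFORCE surrogate from Sec. 3.3: discounted returns built
# from the per-timestep log-likelihoods log p(y_t | A, x, theta), multiplied by the
# log-probabilities of the sampled left/right decisions, with a baseline subtracted.
import numpy as np

def discounted_returns(logp_y, gamma):
    """R_t = sum_{i >= t} gamma^(i-t) * log p(y_i | A, x, theta)."""
    R, acc = np.zeros_like(logp_y), 0.0
    for t in reversed(range(len(logp_y))):
        acc = logp_y[t] + gamma * acc
        R[t] = acc
    return R

def reinforce_surrogate(logp_y, logp_decisions, baselines, gamma=0.95):
    """Scalar whose gradient (under autodiff, with the returns treated as constants)
    matches the Monte Carlo estimate of Eq. (1).

    logp_y:          log p(y_t | A, x, theta) for each output timestep t
    logp_decisions:  list over timesteps; entry t holds the log-probabilities of the
                     log(n) sampled left/right choices of that timestep's attention
    baselines:       predicted baselines, one per timestep
    """
    R = discounted_returns(np.asarray(logp_y, dtype=float), gamma)
    surrogate = np.sum(logp_y)                       # the grad log p(y|A,x) term
    for t, logp_A_t in enumerate(logp_decisions):
        surrogate += (R[t] - baselines[t]) * np.sum(logp_A_t)
    return surrogate

# toy usage: 3 output timesteps, 3 sampled decisions per attention phase
logp_y = [-0.1, -0.7, -0.2]
logp_A = [np.log([0.9, 0.6, 0.8]), np.log([0.5, 0.7, 0.9]), np.log([0.95, 0.8, 0.6])]
print(reinforce_surrogate(logp_y, logp_A, baselines=np.array([-1.0, -1.0, -1.0])))
```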
The whole model is trained with the Adam (Kingma & Ba, 2014) algorithm. We also employ the following three training techniques:

Different reward function. During our experiments we noticed that better results may be obtained by using a different reward function for REINFORCE. More precisely, instead of the log-probability of producing the correct output, we use the number of output bits which have a probability greater than 50% of being predicted correctly (given A), i.e. our discounted return is equal to Σ_{i≥t, 1≤j≤b} γ^{i−t} [p(y_{i,j}|A, x, θ) > 0.5], where [·] denotes the indicator function. Notice that this corresponds to the Hamming distance between the most probable outcome according to the model (given A) and the correct output.

Entropy bonus term. We add a special term to the cost function which encourages exploration. More precisely, for each sampling node we add to the cost function the term α/H(p), where H(p) is the entropy of the distribution of the decision whether to go left or right in this node, and α is an exponentially decaying coefficient. This term goes to infinity whenever the entropy goes to zero, which ensures some level of exploration. We noticed that this term works better in our experiments than the standard term of the form −αH(p) (Williams, 1992).

Curriculum schedule. We start by training on inputs with lengths sampled uniformly from [1, n] for some n = 2^k and the binary tree with n leaves. Whenever the error drops below some threshold, we increment the value k and start using the bigger tree with 2n leaves and inputs with lengths sampled uniformly from [1, 2n].
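The sketch below is a schematic rendering of this curriculum (my own; train_one_epoch and validation_error are hypothetical stand-ins for the real training and evaluation code): inputs are drawn from [1, n] with n = 2^k, and k is incremented whenever the validation error falls below a threshold.

```python
# A minimal sketch of the curriculum schedule from Sec. 3.3; `train_one_epoch` and
# `validation_error` are hypothetical callables standing in for the real training code.

def curriculum(train_one_epoch, validation_error,
               k_start=2, k_max=5, threshold=0.05, epochs=100):
    k = k_start
    for _ in range(epochs):
        n = 2 ** k                                    # memory size / max input length
        train_one_epoch(n_leaves=n, max_len=n)
        if validation_error(n_leaves=n, max_len=n) < threshold and k < k_max:
            k += 1                                    # curriculum step: double the tree
    return k

# toy usage with stand-in callables: the "error" simply decays over time
state = {"err": 1.0}
def fake_train(n_leaves, max_len): state["err"] *= 0.9
def fake_eval(n_leaves, max_len): return state["err"]
print("final k:", curriculum(fake_train, fake_eval, epochs=30))
```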
4. Experiments

In this section, we evaluate two variants of using the HAM module. The first one is the model described in Sec. 3, which combines an LSTM controller with a HAM module (denoted LSTM+HAM). Then, in Sec. 4.3 we investigate the "raw" HAM (without the LSTM controller) to check its capability of acting as classic data structures: a stack, a FIFO queue and a priority queue.

4.1. Test setup

For each test that we perform, we apply the following procedure. First, we train the model with memory of size up to n = 32 using the curriculum schedule described in Sec. 3.3. The model is trained using the minibatch Adam algorithm with an exponentially decaying learning rate. We use random search to determine the best hyper-parameters for the model. We use gradient clipping (Pascanu et al., 2012) with constant 5. The depth of our MLPs is either 1 or 2, the LSTM controller has l = 20 memory cells and the hidden values in the tree have dimensionality d = 20. The constant η determining the number of memory accesses between producing consecutive output symbols (Fig. 4) is equal to either 1 or 2. We always train for 100 epochs, each consisting of 1000 batches of size 50. After each epoch we evaluate the model on 200 validation batches without learning. When the training is finished, we select the model parameters that gave the lowest error rate on the validation batches and report the error obtained with these parameters on 2,500 fresh random examples.

We report two types of errors: a test error and a generalization error. The test error shows how well the model is able to fit the data distribution and generalize to unknown cases, assuming that cases of similar lengths were shown during the training. It is computed using the HAM memory with n = 32 leaves, as the percentage of output sequences which were predicted incorrectly. The lengths of test examples are sampled uniformly from the range [1, n]. Notice that we mark the whole output sequence as incorrect even if only one bit was predicted incorrectly; e.g. a hypothetical model predicting each bit incorrectly with probability 1% (and independently of the errors on the other bits) has an error rate of 96% on whole sequences if outputs consist of 320 bits.

The generalization error shows how well the model performs with enlarged memory on examples with lengths exceeding n. We test our model with memory 4 times bigger than the training one. The lengths of input sequences are now sampled uniformly from the range [2n + 1, 4n].

During testing we make our model fully deterministic by using the most probable outcomes instead of stochastic sampling. More precisely, we assume that during the attention phase the model decides to go right iff p > 0.5 (Fig. 3). Moreover, the output symbols (Fig. 4) are computed by rounding to zero or one instead of sampling.
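The whole-sequence error metric, and the 96% figure quoted above, can be checked with a few lines of NumPy (my own illustration):

```python
# A quick check of the sequence-level error metric used in Sec. 4.1: an output counts
# as wrong if any single bit is wrong, so a model that gets each bit wrong independently
# with probability 1% already has a ~96% sequence error rate on 320-bit outputs.
import numpy as np

def sequence_error_rate(predicted_bits, target_bits):
    """Fraction of sequences with at least one incorrectly predicted bit."""
    wrong = np.any(predicted_bits != target_bits, axis=1)
    return wrong.mean()

# analytic sanity check of the 96% figure quoted in the paper
p_bit_wrong, bits = 0.01, 320
print(1.0 - (1.0 - p_bit_wrong) ** bits)        # ~0.96

# Monte Carlo version of the same statement
rng = np.random.default_rng(0)
targets = rng.integers(0, 2, (20000, bits))
flips = rng.random((20000, bits)) < p_bit_wrong
preds = np.where(flips, 1 - targets, targets)
print(sequence_error_rate(preds, targets))      # close to 0.96
```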
4.2. LSTM+HAM

We evaluate the model on a number of algorithmic tasks described below.

Reverse: Given a sequence of 10-bit vectors, output them in reversed order, i.e. yi = x_{m+1−i} for 1 ≤ i ≤ m, where m is the length of the input sequence.

Search: Given a sequence of pairs xi = key_i || value_i for 1 ≤ i ≤ m − 1, sorted by keys, and a query x_m = q, find the smallest i such that key_i = q and output y1 = value_i. Keys and values are 5-bit vectors and keys are compared lexicographically. The LSTM+HAM model is given only two timesteps (η = 2) to solve this problem, which forces it to use a form of binary search.

Merge: Given two sorted sequences of pairs, (p1, v1), ..., (pm, vm) and (p′1, v′1), ..., (p′_{m′}, v′_{m′}), where pi, p′i ∈ [0, 1] and vi, v′i ∈ {0, 1}^5, merge them. Pairs are compared according to their priorities, i.e. the values pi and p′i. Priorities are unique and sampled uniformly from the set {1/300, ..., 300/300}, because neural networks cannot easily distinguish two real numbers which are very close to each other. The input is encoded as xi = pi || vi for 1 ≤ i ≤ m and x_{m+i} = p′i || v′i for 1 ≤ i ≤ m′. The output consists of the vectors vi and v′i sorted according to their priorities³.

³ Notice that we earlier assumed, for the sake of simplicity, that the input sequences consist of binary vectors, whereas in this task the priorities are real values. This does not, however, require any change to our model. We decided to use real priorities in this task in order to diversify our set of problems.

Sort: Given a sequence of pairs xi = key_i || value_i, sort them in a stable way⁴ according to the lexicographic order of the keys. Keys and values are 5-bit vectors.

⁴ Stability means that pairs with equal keys should be ordered according to their order in the input sequence.

Add: Given two numbers represented in binary, compute their sum. The input is represented as a1, ..., am, +, b1, ..., bm, = (i.e. x1 = a1, x2 = a2 and so on), where a1, ..., am and b1, ..., bm are the bits of the input numbers and +, = are special symbols. Input and output numbers are encoded starting from the least significant bits.

Every example output shown during the training is finished with a special "End Of Output" symbol, which the model learns to predict. This forces the model to learn not only the output symbols, but also the length of the correct output.
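As an illustration of the input-output encoding (my own sketch, not the authors' data pipeline), the following generates one training example for the Sort task: each symbol is key_i || value_i with 5-bit keys and values, and the target is the input stably sorted by the lexicographic order of the keys.

```python
# A small sketch of how a Sort training example could be generated: 5-bit keys and
# values, input symbols key || value, target stably sorted by lexicographic key order.
import numpy as np

rng = np.random.default_rng(0)

def sort_example(m, key_bits=5, value_bits=5):
    keys = rng.integers(0, 2, (m, key_bits))
    values = rng.integers(0, 2, (m, value_bits))
    x = np.concatenate([keys, values], axis=1)          # x_i = key_i || value_i
    # np.lexsort is stable; the primary key (last row passed) is the leftmost key bit
    order = np.lexsort(keys[:, ::-1].T)
    y = x[order]                                        # target: stably sorted pairs
    return x, y

x, y = sort_example(m=6)
print(x)
print(y)
```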
We compare our model with two strong baseline models: the encoder-decoder LSTM (Sutskever et al., 2014) and the encoder-decoder LSTM with attention (denoted LSTM+A) (Bahdanau et al., 2014). The number of LSTM cells in the baselines was chosen in such a way that they have more parameters than the biggest of our models. We also use random search to select an optimal learning rate and some other parameters for the baselines, and we train them using the same curriculum scheme as LSTM+HAM.

The results are presented in Table 1. Not only does LSTM+HAM solve all the problems almost perfectly, but it also generalizes very well to much longer inputs on all problems except Add. Recall that for the generalization tests we used a HAM memory of a different size than the one used during training, which shows that HAM generalizes very well to new sizes of the binary tree. We find this fact quite interesting, because it means that parameters learned in a small network (i.e. HAM based on a tree with 32 leaves) can be successfully used in a different, bigger network (i.e. HAM with 128 memory cells).

In comparison, the LSTM with attention does not learn to merge or sort. It also completely fails to generalize to longer examples, which shows that LSTM+A learns some statistical dependencies between inputs and outputs rather than the real algorithms.

The LSTM+HAM model makes a few errors when tested on longer outputs than the ones encountered during training. Notice, however, that the table reports the percentage of output sequences which contain at least one incorrect bit. For instance, on the problem Merge, LSTM+HAM predicts only 0.03% of output bits incorrectly, which corresponds to 2.48% of incorrect output sequences. We believe that these rare mistakes could be avoided if one trained the model longer and chose the learning rate schedule carefully. One more way to boost generalization capabilities would be to simultaneously train models with different memory sizes and shared parameters. We have not tried this, as the generalization properties of the model were already very good.

Table 1. Experimental results. The upper part presents the error rates on inputs of the same lengths as the ones used during training. The lower part shows the error rates on input sequences 2 to 4 times longer than the ones encountered during training. LSTM+A denotes an LSTM with the standard attention mechanism. Each error rate is the percentage of output sequences which contained at least one incorrectly predicted bit.

    test error            LSTM     LSTM+A    LSTM+HAM
    Reverse               73%      0%        0%
    Search                62%      0.04%     0.12%
    Merge                 88%      16%       0%
    Sort                  99%      25%       0.04%
    Add                   39%      0%        0%

    2-4x longer inputs    LSTM     LSTM+A    LSTM+HAM
    Reverse               100%     100%      0%
    Search                89%      0.52%     1.68%
    Merge                 100%     100%      2.48%
    Sort                  100%     100%      0.24%
    Add                   100%     100%      100%

    Memory access         Θ(1)     Θ(n)      Θ(log n)
    complexity
4.3. Raw HAM

In this section, we evaluate the "raw" HAM module (without the LSTM controller) to see if it can act as a drop-in replacement for 3 classic data structures: a stack, a FIFO queue and a priority queue. For each task, the network is given a sequence of PUSH and POP operations in an online manner: at timestep t the network sees only the t-th operation to perform, xt. This is a more realistic scenario for data structure usage, as it prevents the network from cheating by peeking into the future.

The raw HAM module differs from the LSTM+HAM model from Sec. 3 in the following way:

• The HAM memory is initialized with zeros.
• The t-th output symbol yt is computed using an MLP from the value in the accessed leaf, ha.
• Notice that in the LSTM+HAM model, hLSTM acted as a kind of "query" or "command" guiding the behaviour of HAM. We now use the values xt instead. Therefore, at the t-th timestep we use xt instead of hLSTM wherever hLSTM was used in the original model, e.g. during the attention phase (Fig. 3) we use p = SEARCH(hc, xt) instead of p = SEARCH(hc, hLSTM).

We evaluate raw HAM on the following tasks:

Stack: The "PUSH x" operation places the element x (a 5-bit vector) on top of the stack, and "POP" returns the most recently added element and removes it from the stack.

Queue: The "PUSH x" operation places the element x (a 5-bit vector) at the end of the queue, and "POP" returns the oldest element and removes it from the queue.

PriorityQueue: The "PUSH x p" operation adds the element x with priority p to the queue. The "POP" operation returns the value with the highest priority and removes it from the queue. Both x and p are represented as 5-bit vectors and priorities are compared lexicographically. To avoid ties we assume that all elements have different priorities.

The model was trained with memory of size up to n = 32 on operation sequences of length n. The sequences of PUSH/POP actions for training were selected randomly: the t-th operation out of the n operations in a sequence was POP with probability t/n and PUSH otherwise. To test generalization, we report the error rates with memory of size 4n on sequences of 4n operations.

The results presented in Table 2 show that HAM simulates a stack and a queue perfectly, with no errors whatsoever, even for memory 4 times bigger. For the PriorityQueue task, the model generalizes almost perfectly to the larger memory, with errors in only 0.2% of output sequences.

Table 2. Results of experiments with the raw version of HAM (without the LSTM controller). Error rates are measured as the percentage of operation sequences in which at least one POP query was not answered correctly.

    Task             Test Error    Generalization Error
    Stack            0%            0%
    Queue            0%            0%
    PriorityQueue    0.08%         0.2%
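The operation sequences described above are easy to reproduce; the sketch below (my own illustration) draws the t-th of n operations as POP with probability t/n and also computes the reference answers a perfect stack would give. How a POP on an empty stack should be handled is not specified in the paper, so the sketch simply returns None in that case.

```python
# A minimal sketch of the PUSH/POP training sequences of Sec. 4.3: the t-th of n
# operations is a POP with probability t/n and a PUSH of a random 5-bit vector
# otherwise, so sequences start PUSH-heavy and end POP-heavy.
import numpy as np

rng = np.random.default_rng(0)

def op_sequence(n, bits=5):
    ops = []
    for t in range(1, n + 1):
        if rng.random() < t / n:
            ops.append(("POP", None))
        else:
            ops.append(("PUSH", rng.integers(0, 2, bits)))
    return ops

def stack_targets(ops):
    """Reference answers for the Stack task: what each POP should return."""
    stack, targets = [], []
    for op, arg in ops:
        if op == "PUSH":
            stack.append(arg)
        else:
            targets.append(stack.pop() if stack else None)   # POP on empty: unspecified
    return targets

ops = op_sequence(8)
print([op for op, _ in ops])
print(stack_targets(ops))
```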
4.4. Analysis

In this section, we present some insights into the algorithms learned by the LSTM+HAM model by investigating the hidden representations he learned for a variant of the problem Sort in which we sort 4-bit vectors lexicographically⁵. For demonstration purposes, we use a small tree with n = 8 leaves and d = 6.

⁵ In the problem Sort considered in the experimental results, there are separate keys and values, which forces the model to learn stable sorting. Here, for the sake of simplicity, we consider a simplified version of the problem and do not use separate keys and values.

The trained network performs sorting perfectly. It attends to the leaves in the order corresponding to the order of the sorted input values, i.e. at every timestep HAM attends to the leaf corresponding to the smallest input value among the leaves which have not been attended so far.

It would be interesting to understand exactly which algorithm the network uses to perform this operation. A natural solution would be to store in each hidden node e the smallest input value among the (so far unattended) leaves below e, together with the information whether this smallest value is in the right or the left subtree under e. We present two timesteps of our model, together with some insights into the algorithm used by the network, in Fig. 6.
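For reference, the classical procedure described above can be written down directly; the sketch below (a hand-written baseline, not the learned network) keeps in every inner node the minimum over the not-yet-attended leaves below it, descends towards that minimum, and repairs the path after each access, visiting the leaves in sorted order with Θ(log n) work per step.

```python
# A hand-written reference mirroring the "natural solution" of Sec. 4.4: each inner
# node stores the minimum over the not-yet-attended leaves below it; each timestep
# descends towards that minimum, "attends" the leaf, and repairs the path to the root.
INF = float("inf")

def build(values):
    n = len(values)                       # assume n is a power of two
    mins = [INF] * (2 * n)
    mins[n:] = list(values)               # leaves at indices n .. 2n-1
    for e in range(n - 1, 0, -1):
        mins[e] = min(mins[2 * e], mins[2 * e + 1])
    return mins

def attend_minimum(mins, n):
    e = 1
    while e < n:                          # go towards the subtree holding the minimum
        e = 2 * e if mins[2 * e] <= mins[2 * e + 1] else 2 * e + 1
    leaf = e
    mins[leaf] = INF                      # mark the leaf as already attended
    while e > 1:                          # repair the minima on the path to the root
        e //= 2
        mins[e] = min(mins[2 * e], mins[2 * e + 1])
    return leaf - n                       # index of the attended leaf

values = [9, 3, 7, 1, 6, 2, 8, 5]
mins = build(values)
order = [attend_minimum(mins, len(values)) for _ in values]
print(order)                              # leaves visited in increasing value order
print([values[i] for i in order])         # 1, 2, 3, 5, 6, 7, 8, 9
```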
Figure 6. Two timesteps of the model ((a) the first timestep, (b) the second timestep). The LSTM controller is not shown, to simplify the exposition. The input sequence is presented on the left, below the tree: x1 = 0000, x2 = 1110, x3 = 1101 and so on. The 2x3 grids in the nodes of the tree represent the values he ∈ R^6. White cells correspond to the value 0 and non-white cells correspond to values > 0. The lower-rightmost cells are presented in pink, because we managed to decipher the meaning of this coordinate for the inner nodes. This coordinate in the node e denotes whether the minimum in the subtree (among the values unattended so far) is in the right or the left subtree of e. A value greater than 0 (pink in the picture) means that the minimum is in the right subtree and therefore we should go right while visiting this node in the attention phase. In the first timestep the leftmost leaf (corresponding to the input 0000) is accessed. Notice that the last coordinates (shown in pink) are updated appropriately, e.g. the smallest unattended value at the beginning of the second timestep is 0101, which corresponds to the 6th leaf. It is in the right subtree under the root, and accordingly the last coordinate of the hidden value stored in the root is high (i.e. pink in the figure).

5. Comparison to other models

Comparing neural networks able to learn algorithms is difficult for a few reasons. First of all, there are no well-established benchmark problems for this area. Secondly, the difficulty of a problem often depends on the way inputs and outputs are encoded. For example, the difficulty of the problem of adding long binary numbers depends on whether the numbers are aligned (i.e. the i-th bit of the second number is "under" the i-th bit of the first number) or written next to each other (e.g. 10011+10101). Moreover, we could compare error rates on inputs from the same distribution as the ones seen during training, or compare error rates on inputs longer than the ones seen during training to see if the model "really learned the algorithm". Furthermore, different models scale differently with the memory size, which makes a direct comparison of error rates less meaningful.

As far as we know, our model is the first one which is able to learn a sorting algorithm from pure input-output examples. In (Reed & de Freitas, 2015) it is shown that an LSTM is able to learn to sort short sequences, but it fails to generalize to inputs longer than the ones seen during training. It is quite clear that an LSTM cannot learn a "real" sorting algorithm, because it uses a bounded memory independent of the length of the input. The Neural Programmer-Interpreter (Reed & de Freitas, 2015) is a neural network architecture which is able to learn bubble sort, but it requires strong supervision in the form of execution traces. In comparison, our model can be trained from pure input-output examples, which is crucial if we want to use it to solve problems for which we do not know any algorithms.

An important feature of neural memories is their efficiency. Our HAM module, in comparison to many other recently proposed solutions, is efficient and allows accessing the memory in Θ(log(n)) complexity. In the context of learning algorithms it may sound surprising that, among all the architectures mentioned in Sec. 2, the only ones which can copy a sequence of length n without Θ(n²) operations are: the Reinforcement-Learning NTM (Zaremba & Sutskever, 2015), the model from (Zaremba et al., 2015), the Neural Random-Access Machine (Kurach et al., 2015), and the Queue-Augmented LSTM (Grefenstette et al., 2015). However, the first three models have only been successful on relatively simple tasks. The last model was successful on some synthetic tasks from the domain of Natural Language Processing, which are very different from the tasks we tested our model on, so we cannot directly compare the two models.

Finally, we do not claim that our model is superior to all the other ones, e.g. Neural Turing Machines (NTM) (Graves et al., 2014). We believe that both memory mechanisms are complementary: NTM memory has a built-in associative map functionality, which may be difficult to achieve in HAM. On the other hand, HAM performs better in tasks like sorting due to a built-in bias towards operating on intervals of memory cells. Moreover, HAM allows much more efficient memory access than NTM. It is also quite possible that a machine able to learn algorithms should use many different types of memory, in the same way as the human brain stores a piece of information differently depending on its type and on how long it should be stored (Berntson & Cacioppo, 2009).

6. Conclusions

We presented a new memory architecture for neural networks called Hierarchical Attentive Memory. Its crucial property is that it scales well with the memory size: memory access requires only Θ(log n) operations. This complexity is achieved by using a new attention mechanism based on a binary tree. The novel attention mechanism is not only faster than the standard one used in Deep Learning, but it also facilitates learning algorithms due to the embedded tree structure.
We showed that an LSTM augmented with HAM can learn a number of algorithms like merging, sorting or binary searching from pure input-output examples. In particular, it is the first neural architecture able to learn a sorting algorithm and generalize well to sequences much longer than the ones seen during training.

We believe that some concepts used in HAM, namely the novel attention mechanism and the idea of aggregating information through a binary tree, may find applications in Deep Learning outside of the problem of designing neural memories.

Acknowledgements

We would like to thank Nando de Freitas, Alexander Graves, Serkan Cabi, Misha Denil and Jonathan Hunt for helpful comments and discussions.

A. Using soft attention

One of the open questions in the area of designing neural networks with attention mechanisms is whether to use soft or hard attention. The model described in the paper belongs to the latter class of attention mechanisms, as it makes hard, stochastic choices. The other solution would be to use a soft, differentiable mechanism, which attends to a linear combination of the potential attention targets and does not involve any sampling. The main advantage of such models is that their gradients can be computed exactly.

We now describe how to modify the model to make it fully differentiable ("DHAM"). Recall that in the original model the leaf which is attended at every timestep is sampled stochastically. Instead of that, we now compute at every timestep, for every leaf e, the probability p(e) that this leaf would be attended if we used the stochastic procedure described in Fig. 3. The value p(e) can be computed by multiplying the probabilities of going in the right direction at all the nodes on the path from the root to e.

As the input for the LSTM we then use the value Σ_{e∈L} p(e) · he. During the write phase, we update the values of all the leaves using the formula he := p(e) · WRITE(he, hROOT) + (1 − p(e)) · he. Then, in the update phase we update the values of all the inner nodes, so that the equation he = JOIN(hl(e), hr(e)) is satisfied for each inner node e. Notice that one timestep of the soft version of the model takes time Θ(n), as we have to update the values of all the nodes in the tree. Our model may be seen as a special case of the Gated Graph Neural Network (Li et al., 2015).

This version of the model is fully differentiable and therefore it can be trained using end-to-end backpropagation on the log-probability of producing the correct output. We observed that training DHAM is slightly easier than the REINFORCE version. However, DHAM does not generalize as well as HAM to larger memory sizes.
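A NumPy sketch of one DHAM timestep is given below (my own illustration, with arbitrary stand-ins for SEARCH, JOIN and WRITE, and a single query vector playing the role of the controller signal). Note that the soft write touches every leaf and every inner node is recomputed, which is exactly why a soft timestep costs Θ(n).

```python
# A sketch of one DHAM timestep: leaf attention weights p(e) are products of the
# per-node "go right" probabilities along the root-to-leaf path, the read is the
# p-weighted sum of leaf values, the write gates every leaf by p(e), and all inner
# nodes are then recomputed with JOIN.
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 6                                        # leaves and node dimensionality
h = rng.normal(size=(2 * n, d))                    # h[1 .. 2n-1], heap-style indexing

def search_prob(node_value, query):                # stand-in for SEARCH: any map to (0,1)
    return 1.0 / (1.0 + np.exp(-(node_value @ query)))

def join(left_value, right_value):                 # stand-in for JOIN
    return np.tanh(left_value + right_value)

def write(leaf_value, query):                      # stand-in for WRITE
    return np.tanh(leaf_value + query)

query = rng.normal(size=d)                         # plays the role of the controller signal

# leaf probabilities: propagate path products down the tree
p = np.zeros(2 * n)
p[1] = 1.0
for e in range(1, n):                              # inner nodes only
    right = search_prob(h[e], query)
    p[2 * e], p[2 * e + 1] = p[e] * (1.0 - right), p[e] * right

leaf_p = p[n:]                                     # sums to 1 over the n leaves
read = leaf_p @ h[n:]                              # soft read: sum_e p(e) * h_e

# soft write on every leaf, then rebuild all inner nodes bottom-up
for i in range(n):
    h[n + i] = leaf_p[i] * write(h[n + i], query) + (1.0 - leaf_p[i]) * h[n + i]
for e in range(n - 1, 0, -1):
    h[e] = join(h[2 * e], h[2 * e + 1])

print("attention over leaves:", np.round(leaf_p, 3), "sum =", leaf_p.sum())
```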
References

Bahdanau, Dzmitry, Cho, Kyunghyun, and Bengio, Yoshua. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.

Berntson, G.G. and Cacioppo, J.T. Handbook of Neuroscience for the Behavioral Sciences, volume 1. Wiley, 2009. ISBN 9780470083567.

Graves, Alex, Wayne, Greg, and Danihelka, Ivo. Neural Turing machines. arXiv preprint arXiv:1410.5401, 2014.

Grefenstette, Edward, Hermann, Karl Moritz, Suleyman, Mustafa, and Blunsom, Phil. Learning to transduce with unbounded memory. In Advances in Neural Information Processing Systems, pp. 1819-1827, 2015.

Hochreiter, Sepp and Schmidhuber, Jürgen. Long short-term memory. Neural Computation, 9(8):1735-1780, 1997.

Joulin, Armand and Mikolov, Tomas. Inferring algorithmic patterns with stack-augmented recurrent nets. arXiv preprint arXiv:1503.01007, 2015.

Kaiser, Łukasz and Sutskever, Ilya. Neural GPUs learn algorithms. arXiv preprint arXiv:1511.08228, 2015.

Kalchbrenner, Nal, Danihelka, Ivo, and Graves, Alex. Grid long short-term memory. arXiv preprint arXiv:1507.01526, 2015.

Kingma, Diederik and Ba, Jimmy. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.

Kurach, Karol, Andrychowicz, Marcin, and Sutskever, Ilya. Neural random-access machines. arXiv preprint arXiv:1511.06392, 2015.

Li, Yujia, Tarlow, Daniel, Brockschmidt, Marc, and Zemel, Richard. Gated graph sequence neural networks. arXiv preprint arXiv:1511.05493, 2015.

Morin, Frederic and Bengio, Yoshua. Hierarchical probabilistic neural network language model. In AISTATS, volume 5, pp. 246-252. Citeseer, 2005.

Nair, Vinod and Hinton, Geoffrey E. Rectified linear units improve restricted Boltzmann machines. In Proceedings of the 27th International Conference on Machine Learning (ICML-10), pp. 807-814, 2010.

Pascanu, Razvan, Mikolov, Tomas, and Bengio, Yoshua. Understanding the exploding gradient problem. Computing Research Repository (CoRR) abs/1211.5063, 2012.

Reed, Scott and de Freitas, Nando. Neural programmer-interpreters. arXiv preprint arXiv:1511.06279, 2015.

Schulman, John, Heess, Nicolas, Weber, Theophane, and Abbeel, Pieter. Gradient estimation using stochastic computation graphs. In Advances in Neural Information Processing Systems, pp. 3510-3522, 2015.

Srivastava, Rupesh Kumar, Greff, Klaus, and Schmidhuber, Jürgen. Highway networks. arXiv preprint arXiv:1505.00387, 2015.

Sukhbaatar, Sainbayar, Szlam, Arthur, Weston, Jason, and Fergus, Rob. End-to-end memory networks. arXiv preprint arXiv:1503.08895, 2015.

Sutskever, Ilya, Vinyals, Oriol, and Le, Quoc V. Sequence to sequence learning with neural networks. In Advances in Neural Information Processing Systems, pp. 3104-3112, 2014.

Vinyals, Oriol, Toshev, Alexander, Bengio, Samy, and Erhan, Dumitru. Show and tell: A neural image caption generator. arXiv preprint arXiv:1411.4555, 2014.

Vinyals, Oriol, Fortunato, Meire, and Jaitly, Navdeep. Pointer networks. arXiv preprint arXiv:1506.03134, 2015.

Weston, Jason, Chopra, Sumit, and Bordes, Antoine. Memory networks. arXiv preprint arXiv:1410.3916, 2014.

Williams, Ronald J. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4):229-256, 1992.

Zaremba, Wojciech and Sutskever, Ilya. Reinforcement learning neural Turing machines. arXiv preprint arXiv:1505.00521, 2015.

Zaremba, Wojciech, Mikolov, Tomas, Joulin, Armand, and Fergus, Rob. Learning simple algorithms from examples. arXiv preprint arXiv:1511.07275, 2015.