Abstract

The effectiveness of neural processes (NPs) in modelling posterior prediction maps—the mapping from data to posterior predictive distributions—has significantly improved since their inception. This improvement can be attributed to two principal factors: (1) advancements in the architecture of permutation invariant set functions, which are intrinsic to all NPs; and (2) leveraging symmetries present in the true posterior predictive map, which are problem dependent. Transformers are a notable development in permutation invariant set functions, and their utility within NPs has been demonstrated through the family of models we refer to as transformer neural processes (TNPs). Despite significant interest in TNPs, little attention has been given to incorporating symmetries. Notably, the posterior prediction maps for data that are stationary—a common assumption in spatio-temporal modelling—exhibit translation equivariance. In this paper, we introduce a new family of translation equivariant TNPs (TE-TNPs). Through an extensive range of experiments on synthetic and real-world spatio-temporal data, we demonstrate the effectiveness of TE-TNPs relative to their non-translation-equivariant counterparts and other NP baselines.

¹Department of Engineering, University of Cambridge, Cambridge, UK. ²Vector Institute, University of Toronto, Toronto, Canada. ³Microsoft Research AI for Science, Cambridge, UK. Correspondence to: Matthew Ashman <mca39@cam.ac.uk>. Proceedings of the 41st International Conference on Machine Learning, Vienna, Austria. PMLR 235, 2024. Copyright 2024 by the author(s).

1. Introduction

Transformers have emerged as an immensely effective architecture for natural language processing and computer vision tasks (Vaswani et al., 2017; Dosovitskiy et al., 2020). They have become the backbone for many state-of-the-art
models—such as ChatGPT (Achiam et al., 2023) and DALL-E (Betker et al., 2023)—owing to their ability to learn complex dependencies amongst input data. More generally, transformers can be understood as permutation equivariant set functions. This abstraction has led to the deployment of transformers in domains beyond sequence modelling, including particle physics, molecular modelling, climate science, and Bayesian inference (Lee et al., 2019; Fuchs et al., 2020; Müller et al., 2021).

NPs (Garnelo et al., 2018a;b) are a broad family of meta-learning models which learn the mapping from sets of observed datapoints to predictive stochastic processes (Foong et al., 2020). They are straightforward to train, handle off-the-grid data and missing observations with ease, and can be easily adapted to different data modalities. This flexibility makes them an attractive choice for a wide variety of problem domains, including spatio-temporal modelling, healthcare, and few-shot learning (Jha et al., 2022). Exchangeability of the predictive distribution with respect to the context set is achieved through the use of permutation invariant set functions, which, in NPs, map from sets of observations to some representation space. Given the utility of transformers as set functions, it is natural to consider their use within NPs. This gives rise to TNPs. The family of TNPs includes the attentive NP (ANP) (Kim et al., 2019), the diagonal TNP (TNP-D), autoregressive TNP (TNP-AR), and non-diagonal TNP (TNP-ND) (Nguyen & Grover, 2022), and the latent-bottlenecked ANP (LBANP) (Feng et al., 2023). Despite a significant amount of interest in TNPs from the research community, there are certain properties that we may wish our model to possess that have not yet been addressed.
In particular, for spatio-temporal problems the data are often roughly stationary, in which case it is desirable to equip our model with translation equivariance: if the data are translated in space or time, then the predictions of our model should be translated correspondingly. Although translation equivariance has been incorporated into other families of NP models, such as the convolutional conditional NP (ConvCNP) (Gordon et al., 2019) and the relational CNP (RCNP) (Huang et al., 2023), it is yet to be incorporated into the TNP. The key ingredient to achieving this is to establish effective translation equivariant attention layers that can be used in place of the standard attention layers within the transformer encoder. In this paper, we develop the TE-TNP.

arXiv:2406.12409v1 [stat.ML] 18 Jun 2024
Translation Equivariant Transformer Neural Processes

Our contributions are as follows:

1. We develop an effective method for incorporating translation equivariance into the attention mechanism of transformers, developing the translation equivariant multi-head self attention (TE-MHSA) and translation equivariant multi-head cross attention (TE-MHCA) operations. These operations replace standard MHSA and MHCA operations within transformer encoders to obtain a new family of translation equivariant TNPs.

2. We use pseudo-tokens to reduce the quadratic computational complexity of TE-TNPs, developing translation equivariant PT-TNPs (TE-PT-TNPs).

3. We demonstrate the efficacy of TE-TNPs relative to existing NPs—including the ConvCNP and the RCNP—on a number of synthetic and real-world spatio-temporal modelling problems.

2. Background

Throughout this section, we use the following notation. Let $\mathcal{X} = \mathbb{R}^{D_x}$ and $\mathcal{Y} = \mathbb{R}^{D_y}$ denote the input and output spaces, and let $(x, y) \in \mathcal{X} \times \mathcal{Y}$ denote an input–output pair. Let $\mathcal{S} = \bigcup_{N=0}^{\infty} (\mathcal{X} \times \mathcal{Y})^N$ be the collection of all finite data sets, which includes the empty set $\emptyset$, the data set containing no data points.
We denote a context and target set by $\mathcal{D}_c, \mathcal{D}_t \in \mathcal{S}$, where $|\mathcal{D}_c| = N_c$ and $|\mathcal{D}_t| = N_t$. Let $\mathbf{X}_c \in \mathbb{R}^{N_c \times D_x}$ and $\mathbf{Y}_c \in \mathbb{R}^{N_c \times D_y}$ be the inputs and corresponding outputs of $\mathcal{D}_c$, with $\mathbf{X}_t \in \mathbb{R}^{N_t \times D_x}$ and $\mathbf{Y}_t \in \mathbb{R}^{N_t \times D_y}$ defined analogously. We denote a single task as $\xi = (\mathcal{D}_c, \mathcal{D}_t) = ((\mathbf{X}_c, \mathbf{Y}_c), (\mathbf{X}_t, \mathbf{Y}_t))$. Let $\mathcal{P}(\mathcal{X})$ denote the collection of stochastic processes on $\mathcal{X}$.

2.1. Neural Processes

NPs (Garnelo et al., 2018a;b) aim to learn the mapping from context sets $\mathcal{D}_c$ to ground-truth posterior distributions over the target outputs, $\mathcal{D}_c \mapsto p(\mathbf{Y}_t \mid \mathbf{X}_t, \mathcal{D}_c)$, using meta-learning. This mapping is known as the posterior prediction map $\pi_P : \mathcal{S} \to \mathcal{P}(\mathcal{X})$, where $P$ denotes the ground-truth stochastic process over functions mapping from $\mathcal{X}$ to $\mathcal{Y}$. Common to all NP architectures are an encoder and a decoder. The encoder maps from $\mathcal{D}_c$ and $\mathbf{X}_t$ to some representation, $e(\mathcal{D}_c, \mathbf{X}_t)$.¹ The decoder takes as input the representation and target inputs $\mathbf{X}_t$ and outputs $d(\mathbf{X}_t, e(\mathcal{D}_c, \mathbf{X}_t))$, the parameters of the predictive distribution over the target outputs $\mathbf{Y}_t$: $p(\mathbf{Y}_t \mid \mathbf{X}_t, \mathcal{D}_c) = p(\mathbf{Y}_t \mid d(\mathbf{X}_t, e(\mathcal{D}_c, \mathbf{X}_t)))$. An important requirement of the predictive distribution is permutation invariance with respect to the elements of $\mathcal{D}_c$.

We shall focus on CNPs (Garnelo et al., 2018a), which factorise the predictive distribution as $p(\mathbf{Y}_t \mid \mathbf{X}_t, \mathcal{D}_c) = \prod_{n=1}^{N_t} p(y_{t,n} \mid d(x_{t,n}, e(\mathcal{D}_c, x_{t,n})))$. CNPs are trained by maximising the posterior predictive likelihood:

$$\mathcal{L}_{\mathrm{ML}} = \mathbb{E}_{p(\xi)}\left[\sum_{n=1}^{N_t} \log p(y_{t,n} \mid d(x_{t,n}, e(\mathcal{D}_c, x_{t,n})))\right]. \quad (1)$$

Here, the expectation is taken with respect to the distribution of tasks $p(\xi)$. As shown in Foong et al. (2020), the global maximum is achieved if and only if the model recovers the ground-truth posterior prediction map. When training a CNP, we often approximate the expectation with an average over the finite number of tasks available.

¹In many NP architectures, including the original conditional NP (CNP) and NP, the representation does not depend on the target inputs $\mathbf{X}_t$.

2.2.
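As a concrete reference, the factorised CNP objective in Equation (1) can be sketched as follows for a single task, assuming a Gaussian predictive distribution; `encoder` and `decoder` stand in for arbitrary learnable networks, and all names here are illustrative rather than the paper's implementation.

```python
import numpy as np

def gaussian_log_prob(y, mean, var):
    # Elementwise log-density of y under N(mean, var).
    return -0.5 * (np.log(2 * np.pi * var) + (y - mean) ** 2 / var)

def cnp_log_likelihood(encoder, decoder, task):
    # task = ((Xc, Yc), (Xt, Yt)). The objective factorises over target
    # points: each prediction depends on the context set and a single
    # target input, matching Equation (1) for one task xi.
    (Xc, Yc), (Xt, Yt) = task
    total = 0.0
    for xt, yt in zip(Xt, Yt):
        mean, var = decoder(xt, encoder((Xc, Yc), xt))
        total += gaussian_log_prob(yt, mean, var).sum()
    return total
```

In practice, the expectation over tasks $p(\xi)$ is approximated by averaging this quantity over a minibatch of sampled tasks.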
Transformers

A useful perspective is to understand transformers as permutation equivariant set functions.² A transformer takes as input a set of $N$ tokens, $Z^0 \in \mathbb{R}^{N \times D_z}$, and outputs a set of tokens of the same cardinality: $f : (\mathbb{R}^{D_z})^N \to (\mathbb{R}^{D_z})^N$. If the input set is permuted, then the output set is permuted accordingly: $f(z_1, \ldots, z_N)_n = f(z_{\sigma(1)}, \ldots, z_{\sigma(N)})_{\sigma(n)}$ for all permutations $\sigma \in S_N$ of $N$ elements.

At the core of each layer of the transformer architecture is the multi-head self attention (MHSA) operation (Vaswani et al., 2017). Let $Z^\ell \in \mathbb{R}^{N \times D_z}$ denote the input set to the $\ell$-th MHSA operation. The MHSA operation updates the $n$-th token $z_n^\ell$ as

$$\tilde{z}_n^\ell = \mathrm{cat}\left\{\sum_{m=1}^{N} \alpha_h^\ell(z_n^\ell, z_m^\ell)\,[z_m^\ell]^T W_\ell^{V,h}\right\}_{h=1}^{H_\ell} W_\ell^O \quad (2)$$

where cat denotes the concatenation operation across the last dimension. Here, $W_\ell^{V,h} \in \mathbb{R}^{D_z \times D_V}$ and $W_\ell^O \in \mathbb{R}^{H_\ell D_V \times D_z}$ are the value and projection weight matrices, and $H_\ell$ denotes the number of 'heads' in layer $\ell$. Note that permutation equivariance is achieved through the permutation invariant summation operator. As this is the only mechanism through which the tokens interact with each other, permutation equivariance for the overall model is ensured. The attention mechanism, $\alpha_h^\ell$, is implemented as

$$\alpha_h^\ell(z_n^\ell, z_m^\ell) = \frac{\exp\left([z_n^\ell]^T W_\ell^{Q,h} [W_\ell^{K,h}]^T z_m^\ell\right)}{\sum_{m'=1}^{N} \exp\left([z_n^\ell]^T W_\ell^{Q,h} [W_\ell^{K,h}]^T z_{m'}^\ell\right)} \quad (3)$$

where $W_\ell^{Q,h} \in \mathbb{R}^{D_z \times D_{QK}}$ and $W_\ell^{K,h} \in \mathbb{R}^{D_z \times D_{QK}}$ are the query and key weight matrices. The softmax normalisation ensures that $\sum_{m=1}^{N} \alpha_h^\ell(z_n^\ell, z_m^\ell) = 1$ for all $n, h, \ell$.

²Note that not all permutation equivariant set functions can be represented by transformers. For example, the family of informers (Garnelo & Czarnecki, 2023) cannot be represented by transformers, yet are permutation equivariant set functions. However, transformers are universal approximators of permutation equivariant set functions (Lee et al., 2019; Wagstaff et al., 2022).

Often, conditional independencies amongst the set of tokens—in
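A minimal sketch of Equations (2) and (3) for a single layer, assuming weights stored per head; following the formulation above, the common $1/\sqrt{D_{QK}}$ scaling is omitted.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def mhsa(Z, WQ, WK, WV, WO):
    # Z: (N, Dz). WQ, WK: (H, Dz, Dqk); WV: (H, Dz, Dv); WO: (H*Dv, Dz).
    heads = []
    for Q, K, V in zip(WQ, WK, WV):
        scores = (Z @ Q) @ (Z @ K).T     # (N, N): Eq. (3) before softmax
        attn = softmax(scores, axis=-1)  # each row sums to one
        heads.append(attn @ (Z @ V))     # Eq. (2), one head
    return np.concatenate(heads, axis=-1) @ WO  # (N, Dz)
```

Permutation equivariance is easy to verify directly: permuting the rows of $Z$ permutes the rows of the output identically, because the per-token sum over $m$ is permutation invariant.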
the sense that the set $\{z_n^\ell\}_{\ell=1}^{L}$ does not depend on the set $\{z_m^\ell\}_{\ell=1}^{L}$ given some other set of tokens, for some $n, m \in \{1, \ldots, N\}$—are desirable. Whilst this is typically achieved through masking, if the same set of tokens is conditioned on for every $n$, then it is more computationally efficient to use multi-head cross attention (MHCA) operations together with MHSA operations than it is to compute Equation 2 directly. The MHCA operation updates the $n$-th token $z_n^\ell$ using the set of tokens $\{\hat{z}_m^\ell\}_{m=1}^{M}$ as

$$\tilde{z}_n^\ell = \mathrm{cat}\left\{\sum_{m=1}^{M} \alpha_h^\ell(z_n^\ell, \hat{z}_m^\ell)\,[\hat{z}_m^\ell]^T W_\ell^{V,h}\right\}_{h=1}^{H_\ell} W_\ell^O. \quad (4)$$

Note that all tokens updated in this manner are conditionally independent of each other given $\{\hat{z}_m^\ell\}_{m=1}^{M}$. We discuss this in more detail in Appendix B. MHCA operations are at the core of pseudo-token-based transformers such as the perceiver (Jaegle et al., 2021) and the induced set transformer (IST) (Lee et al., 2019). We describe their differences in the following section. MHSA and MHCA operations are used in combination with layer-normalisation operations and pointwise MLPs to obtain MHSA and MHCA blocks. Unless stated otherwise, we adopt the order used by Vaswani et al. (2017).

2.3. Pseudo-Token-Based Transformers

Pseudo-token-based transformers reduce the quadratic computational complexity of the standard transformer through the use of pseudo-tokens. Concretely, let $U \in \mathbb{R}^{M \times D_z}$ denote an initial set of $M \ll N$ tokens we call pseudo-tokens. There are two established methods for incorporating information about the set of observed tokens $Z$ into these pseudo-tokens in a computationally efficient manner: the perceiver-style approach of Jaegle et al. (2021) and the IST-style approach of Lee et al. (2019). The perceiver-style approach iterates between applying $\mathrm{MHCA}(U^\ell, Z^\ell)$ and $\mathrm{MHSA}(U^\ell)$, outputting a set of $M$ pseudo-tokens, and has a computational complexity of $\mathcal{O}(MN)$ at each layer.
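Equation (4) and the perceiver-style update can be sketched as follows; `perceiver_layer` and its parameter names are illustrative, and residual connections, layer normalisation, and pointwise MLPs are omitted for brevity.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def mhca(Z, Z_hat, WQ, WK, WV, WO):
    # Eq. (4): queries from Z (N, Dz); keys and values from Z_hat (M, Dz).
    heads = []
    for Q, K, V in zip(WQ, WK, WV):
        attn = softmax((Z @ Q) @ (Z_hat @ K).T, axis=-1)  # (N, M)
        heads.append(attn @ (Z_hat @ V))
    return np.concatenate(heads, axis=-1) @ WO

def perceiver_layer(U, Z, cross_w, self_w):
    # Perceiver-style update: MHCA(U, Z) costs O(M*N); MHSA over the
    # pseudo-tokens is MHCA of U with itself and costs O(M^2).
    U = mhca(U, Z, *cross_w)
    return mhca(U, U, *self_w)
```

Each output row of `mhca` depends only on the corresponding query token and `Z_hat`, which is precisely the conditional-independence property noted above.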
The IST-style approach iterates between applying $\mathrm{MHCA}(U^\ell, Z^\ell)$ and $\mathrm{MHCA}(Z^\ell, U^\ell)$, outputting a set of $N$ tokens and $M$ pseudo-tokens, and also has a computational complexity of $\mathcal{O}(MN)$ at each layer. We provide illustrations of these differences in Appendix C.

2.4. Transformer Neural Processes

Given the utility of transformers as set functions, it is natural to consider their use in the encoder of an NP—we describe this family of NPs as TNPs. Let $Z_c^0 \in \mathbb{R}^{N_c \times D}$ denote the initial set-of-tokens representation of the input–output pairs $(x_{c,n}, y_{c,n}) \in \mathcal{D}_c$, and $Z_t^0 \in \mathbb{R}^{N_t \times D}$ denote the initial set-of-tokens representation of the inputs $x_{t,n} \in \mathbf{X}_t$. The encoding $e(\mathcal{D}_c, \mathbf{X}_t)$ of TNPs is obtained by passing the union of initial context and target tokens, $Z^0 = [Z_c^0, Z_t^0]$, through a transformer-style architecture and keeping only the output tokens corresponding to the target inputs, $Z_t^L$. The specific transformer-style architecture is unique to each TNP variant. However, they generally consist of MHSA operations acting on the context tokens and MHCA operations acting to update the target tokens given the context tokens.³ The combination of MHSA and MHCA operations is a permutation invariant function with respect to the context tokens. We provide an illustration of this in Figure 1a. Enforcing these conditional independencies ensures that the final target token $z_{t,n}^L$ depends only on $\mathcal{D}_c$ and $x_{t,n}$, i.e. $[e(\mathcal{D}_c, \mathbf{X}_t)]_n = e(\mathcal{D}_c, x_{t,n})$. This is required for the factorisation of the predictive distribution $p(\mathbf{Y}_t \mid \mathbf{X}_t, \mathcal{D}_c) = \prod_{n=1}^{N_t} p(y_{t,n} \mid d(x_{t,n}, e(\mathcal{D}_c, x_{t,n})))$. We refer to the family of TNPs which use pseudo-token-based transformers as pseudo-token TNPs (PT-TNPs). Currently, this family is restricted to the LBANP, which uses a perceiver-style architecture; however, it is straightforward to use an IST-style architecture instead.

2.5. Translation Equivariance

Here, we provide new theoretical results which show the importance of translation equivariance as an inductive bias in NPs.
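The generic TNP encoder of Section 2.4 can be sketched as follows; this is a minimal single-head sketch with illustrative names, and the masked single-pass MHSA variant mentioned in the footnote is an equivalent formulation.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def attn(Q_in, KV_in, WQ, WK, WV):
    # Single-head attention: queries from Q_in, keys/values from KV_in.
    a = softmax((Q_in @ WQ) @ (KV_in @ WK).T, axis=-1)
    return a @ (KV_in @ WV)

def tnp_encoder(Zc, Zt, layers):
    # Each layer: self-attention over the context tokens, then
    # cross-attention updating the target tokens from the context tokens.
    # Target tokens never feed back into the context, so each final
    # target token depends only on Dc and its own target input.
    for self_w, cross_w in layers:
        Zc = Zc + attn(Zc, Zc, *self_w)
        Zt = Zt + attn(Zt, Zc, *cross_w)
    return Zt  # e(Dc, Xt): keep only the target tokens
```

The conditional-independence property required for the factorised predictive distribution follows directly: perturbing one target token leaves every other output token unchanged.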
In particular, we first show that the posterior prediction map is translation equivariant if, and only if, the ground-truth stochastic process is stationary (Theorem 2.1). Second, we show the importance of translation equivariance for the ability of our models to generalise to settings outside of the training distribution (Theorem 2.2), for which Figure 3 provides some intuition.

Let $\mathcal{T}_\tau$ denote a translation by $\tau \in \mathbb{R}^{D_x}$. For a data set $\mathcal{D} \in \mathcal{S}$, $\mathcal{T}_\tau \mathcal{D} \in \mathcal{S}$ translates the data set by adding $\tau$ to all inputs. For a function $f : \mathcal{X} \to \mathcal{Z}$, $\mathcal{T}_\tau f$ translates $f$ by producing a new function $\mathcal{X} \to \mathcal{Z}$ such that $\mathcal{T}_\tau f(x) = f(x - \tau)$ for all $x \in \mathbb{R}^{D_x}$. For a stochastic process $\mu \in \mathcal{P}(\mathcal{X})$, $\mathcal{T}_\tau(\mu)$ denotes the pushforward measure of pushing $\mu$ through $\mathcal{T}_\tau$. A prediction map $\pi$ is a mapping $\pi : \mathcal{S} \to \mathcal{P}(\mathcal{X})$ from data sets $\mathcal{S}$ to stochastic processes $\mathcal{P}(\mathcal{X})$. Prediction maps are mathematical models of neural processes. Say that a prediction map $\pi$ is translation equivariant if $\mathcal{T}_\tau \circ \pi = \pi \circ \mathcal{T}_\tau$ for all translations $\tau \in \mathbb{R}^{D_x}$.

The ground-truth stochastic process $P$ is stationary if and only if the prediction map $\pi_P$ is translation equivariant. Foong et al. (2020) provide a simple proof of the "only if" direction. We provide a rigorous proof in both directions. Consider $\mathcal{D} \in \mathcal{S}$. Formally define $\pi_P(\mathcal{D})$ by integrating $P$ against a density $\pi_P'(\mathcal{D})$ that depends on $\mathcal{D}$, so $\mathrm{d}\pi_P(\mathcal{D}) = \pi_P'(\mathcal{D})\,\mathrm{d}P$.⁴ Assume that $\pi_P'(\emptyset) \propto 1$, so $\pi_P(\emptyset) = P$. Say

³As discussed in Section 2.2, this is often implemented as a single MHSA operation with masking.

⁴Intuitively, $\pi_P'(\mathcal{D})(f) = p(\mathcal{D} \mid f)/p(\mathcal{D})$, so $\pi_P'(\mathcal{D})(f)$ is the

[Figure 1 block diagrams: (a) TNP; (b) TE-TNP.]

Figure 1.
Block diagrams illustrating the TNP and TE-TNP encoder architectures. For both models, we pass individual datapoints through pointwise MLPs to obtain the initial token representations, $Z_c^0$ and $Z_t^0$. These are then passed through multiple attention layers, with the context tokens interacting with the target tokens through cross-attention. The output of the encoder depends on $\mathcal{D}_c$ and $\mathbf{X}_t$. The TE-TNP encoder updates the input locations at each layer, in addition to the tokens.

Figure 2. Average log-likelihood (↑) on the test datasets for the synthetic 1-D regression experiment. $\Delta$ denotes the amount by which the range from which the context and target inputs are sampled is shifted at test time. Standard errors are shown.

that $\pi_P'$ is translation invariant if, for all $\mathcal{D} \in \mathcal{S}$ and $\tau \in \mathcal{X}$, $\pi_P'(\mathcal{T}_\tau \mathcal{D}) \circ \mathcal{T}_\tau = \pi_P'(\mathcal{D})$ $P$-almost surely.⁵

Theorem 2.1. (1) The ground-truth stochastic process $P$ is stationary and $\pi_P'$ is translation invariant if and only if (2) $\pi_P$ is translation equivariant.

See Appendix E for the proof. If the ground-truth stochastic process is stationary, it is helpful to build translation equivariance into the neural process: this greatly reduces the model space to search over, which can significantly improve data efficiency (Foong et al., 2020). In addition, it is possible to show that translation equivariant NPs generalise spatially. We formalise this in the following theorem, which we present in the one-dimensional setting ($D_x = 1$) for notational simplicity, and we provide an illustration of these ideas in Figure 3.

Definitions for the theorem. For a stochastic process $f \sim \mu$ with $\mu \in \mathcal{P}(\mathcal{X})$ and every $x \in \mathbb{R}^N$, denote the distribution of $(f(x_1), \ldots, f(x_N))$ by $P_x \mu$. We now define the notion of the receptive field. For two vectors of inputs $x_1 \in \mathbb{R}^{N_1}$, $x_2 \in \mathbb{R}^{N_2}$, and $R > 0$, let $x_1|_{x_2,R}$ be the subvector of $x_1$ with inputs at most distance $R$ away from any input in $x_2$.
Similarly, for a data set $\mathcal{D} = (x, y) \in \mathcal{S}$, let $\mathcal{D}|_{x_2,R} \in \mathcal{S}$ be the subset of data points of $\mathcal{D}$ with inputs at most distance $R$ away from any input in $x_2$. With these definitions, say that a stochastic process $f \sim \mu$ with $\mu \in \mathcal{P}(\mathcal{X})$ has receptive field $R > 0$ if, for all $N_1, N_2 \in \mathbb{N}$, $x_1 \in \mathbb{R}^{N_1}$, and $x_2 \in \mathbb{R}^{N_2}$, $f(x_2) \mid f(x_1) \stackrel{d}{=} f(x_2) \mid f(x_1|_{x_2, \frac{1}{2}R})$. Intuitively, $f$ only has local dependencies. Moreover, say that a prediction map $\pi : \mathcal{S} \to \mathcal{P}(\mathcal{X})$ has receptive field $R > 0$ if, for all $\mathcal{D} \in \mathcal{S}$, $N \in \mathbb{N}$, and $x \in \mathbb{R}^N$, $P_x \pi(\mathcal{D}) = P_x \pi(\mathcal{D}|_{x, \frac{1}{2}R})$. Intuitively, predictions by the neural process $\pi$ are only influenced by context points at most distance $\frac{1}{2}R$ away.⁶

Theorem 2.2. Let $\pi_1, \pi_2 : \mathcal{S} \to \mathcal{P}(\mathcal{X})$ be translation equivariant prediction maps with receptive field $R > 0$. Assume that, for all $\mathcal{D} \in \mathcal{S}$, $\pi_1(\mathcal{D})$ and $\pi_2(\mathcal{D})$ also have receptive field $R > 0$. Let $\epsilon > 0$ and fix $N \in \mathbb{N}$. Assume that, for all $x \in \bigcup_{n=1}^{N} [0, 2R]^n$ and $\mathcal{D} \in \mathcal{S} \cap \bigcup_{n=0}^{\infty} ([0, 2R] \times \mathbb{R})^n$,

$$\mathrm{KL}\left[P_x \pi_1(\mathcal{D}) \,\|\, P_x \pi_2(\mathcal{D})\right] \leq \epsilon. \quad (5)$$

Then, for all $M > 0$, $x \in \bigcup_{n=1}^{N} [0, M]^n$, and $\mathcal{D} \in \mathcal{S} \cap \bigcup_{n=0}^{\infty} ([0, M] \times \mathbb{R})^n$,

$$\mathrm{KL}\left[P_x \pi_1(\mathcal{D}) \,\|\, P_x \pi_2(\mathcal{D})\right] \leq \lceil 2M/R \rceil \epsilon. \quad (6)$$

See Appendix E for the proof. The notion of receptive field is natural to CNNs and corresponds to the usual notion of the receptive field. The notion is also inherent to transformers that adopt sliding-window attention: the size of the window multiplied by the number of transformer layers gives the receptive field of the model.

⁴(continued) modelling assumption that specifies how observations are generated from the ground-truth stochastic process. A simple example is $\pi_P'(\mathcal{D})(f) \propto \prod_{(x,y) \in \mathcal{D}} \mathcal{N}(y \mid f(x), \sigma^2)$, which adds independent Gaussian noise with variance $\sigma^2$.

⁵For example, the usual Gaussian likelihood is translation invariant: $\mathcal{N}(y \mid (\mathcal{T}_\tau f)(x + \tau), \sigma^2) = \mathcal{N}(y \mid f(x), \sigma^2)$.
Intuitively, this theorem states that, if (a) the ground-truth stochastic process and our model are translation equivariant and (b) everything has receptive field size $R > 0$, then, whenever our model is accurate on $[0, 2R]$, it is also accurate on any bigger interval $[0, M]$. Note that this theorem accounts for dependencies between target points, so it also applies to latent-variable neural processes (Garnelo et al., 2018b) and Gaussian neural processes (Bruinsma et al., 2021; Markou et al., 2022).

⁶For example, if $f$ is a Gaussian process with a kernel compactly supported on $[-\frac{1}{2}R, \frac{1}{2}R]$, then the mapping $\mathcal{D} \mapsto p(f \mid \mathcal{D})$ is a prediction map which (a) has receptive field $R$ and (b) maps to stochastic processes with receptive field $R$.

Figure 3. Translation equivariance in combination with a limited receptive field (see (a)) can help generalisation performance. (a) For a model with receptive field $R > 0$, a context point at $x_c$ influences predictions at target inputs only limitedly far away. Conversely, a prediction at a target input $x_t$ is influenced by context points only limitedly far away. (b) If a model is translation equivariant, then all context points and target inputs can simultaneously be shifted left or right without changing the output of the model. Intuitively, this means that the triangles in the figure can simply be "shifted left or right". Consider a translation equivariant (TE) model which performs well within a training range, and consider a prediction for a target input outside the training range (right triangle in (b)). If the model has receptive field $R > 0$ and the training range is bigger than $R$, then TE can be used to "shift that prediction back into the training range". Since the model performs well within the training range, the model also performs well for the target input outside the training range.
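As a toy illustration of these ideas (not the paper's model): a kernel smoother whose weights depend only on input differences and vanish beyond distance $\frac{1}{2}R$ is a translation equivariant prediction map with receptive field $R$, and both properties can be checked directly.

```python
import numpy as np

def predict(Xc, Yc, Xt, R=2.0, ls=0.3):
    # Nadaraya-Watson smoother. The weights depend only on the
    # differences x_t - x_c (giving translation equivariance) and are
    # truncated at distance R/2 (giving receptive field R).
    d = np.abs(Xt[:, None] - Xc[None, :])
    w = np.exp(-0.5 * (d / ls) ** 2) * (d <= R / 2)
    return (w * Yc[None, :]).sum(axis=1) / np.maximum(w.sum(axis=1), 1e-12)
```

Translating every input by the same $\tau$ leaves the predictions unchanged, matching Figure 3b, while context points further than $\frac{1}{2}R$ from all target inputs have no influence, matching Figure 3a.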
In practice, we do not explicitly equip our transformer neural processes with a fixed receptive field size by using sliding-window attention, although this would certainly be possible. The main purpose of this theorem is to elucidate the underlying principles that enable translation equivariant neural processes to generalise.

3. Translation Equivariant Transformer Neural Processes

Let $Z^\ell \in \mathbb{R}^{N \times D_z}$ and $\tilde{Z}^\ell \in \mathbb{R}^{N \times D_z}$ denote the inputs and outputs at layer $\ell$, respectively. To achieve translation equivariance, we must ensure that each token $z_n^\ell$ is translation invariant with respect to the inputs, which can be achieved through dependency solely on the pairwise differences $x_i - x_j$ (see Appendix D). We choose to let our initial token embeddings $z_n^0$ depend solely on the corresponding output $y_n$, and introduce dependency on the pairwise differences through the attention mechanism. For permutation and translation equivariant operations, we choose updates of the form $\tilde{z}_n^\ell = \phi$
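The update rule is cut off here. As a rough standalone illustration of the key idea (attention weights computed from pairwise input differences $x_n - x_m$ rather than from absolute positions), the following sketch uses an RBF score as an illustrative stand-in for the paper's learned function of pairwise differences; it is not the TE-MHSA operation itself.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def te_attention(Z, X, WV, gamma=1.0):
    # Z: (N, Dz) tokens; X: (N, Dx) input locations. The attention
    # scores depend only on pairwise differences x_n - x_m, so the
    # output tokens are invariant to translating all inputs by tau.
    diffs = X[:, None, :] - X[None, :, :]       # (N, N, Dx)
    scores = -gamma * (diffs ** 2).sum(axis=-1) # translation invariant
    return softmax(scores, axis=-1) @ (Z @ WV)
```

Because the tokens are translation invariant while the input locations translate with the data, the resulting prediction map is translation equivariant.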