
Uncovering hidden geometry in Transformers via disentangling position and context

Abstract

Transformers are widely used to extract complex semantic meaning from input tokens, yet they usually operate as black-box models. In this paper, we present a simple yet informative decomposition of the hidden states (or embeddings) of trained transformers into interpretable components. For any layer, the embedding vectors of the input sequence samples are represented by a tensor $\boldsymbol{h} \in \mathbb{R}^{C \times T \times d}$. Given an embedding vector $\boldsymbol{h}_{c,t} \in \mathbb{R}^d$ at sequence position $t \le T$ in a sequence (or context) $c \le C$, extracting the mean effects yields the decomposition \[ \boldsymbol{h}_{c,t} = \boldsymbol{\mu} + \mathbf{pos}_t + \mathbf{ctx}_c + \mathbf{resid}_{c,t}, \] where $\boldsymbol{\mu}$ is the global mean vector, $\mathbf{pos}_t$ and $\mathbf{ctx}_c$ are the mean vectors across contexts and across positions respectively, and $\mathbf{resid}_{c,t}$ is the residual vector. For popular transformer architectures and diverse text datasets, we empirically find pervasive mathematical structure: (1) $(\mathbf{pos}_t)_{t}$ forms a low-dimensional, continuous, and often spiral shape across layers; (2) $(\mathbf{ctx}_c)_c$ shows clear cluster structure that falls into context topics; and (3) $(\mathbf{pos}_t)_{t}$ and $(\mathbf{ctx}_c)_c$ are mutually incoherent -- namely $\mathbf{pos}_t$ is almost orthogonal to $\mathbf{ctx}_c$ -- a property that is canonical in compressed sensing and dictionary learning. This decomposition offers structural insights about input formats in in-context learning (especially for induction heads) and in arithmetic tasks.
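
The decomposition above is a two-way mean-effect extraction over positions and contexts. Below is a minimal sketch of how it could be computed, assuming the hidden states of one layer are stored as a NumPy array `h` of shape (C, T, d); the random array and the shape values here are illustrative placeholders, not the paper's data or code.

```python
import numpy as np

# Stand-in for one layer's embeddings: C contexts, T positions, d dimensions.
C, T, d = 8, 16, 32
h = np.random.randn(C, T, d)

mu = h.mean(axis=(0, 1))       # global mean vector, shape (d,)
pos = h.mean(axis=0) - mu      # positional mean effects, shape (T, d)
ctx = h.mean(axis=1) - mu      # contextual mean effects, shape (C, d)
resid = h - mu - pos[None, :, :] - ctx[:, None, :]  # residuals, shape (C, T, d)

# Sanity check: the four components reconstruct h exactly (holds by construction).
assert np.allclose(h, mu + pos[None, :, :] + ctx[:, None, :] + resid)
```

On real trained-transformer embeddings, one would then inspect $(\mathbf{pos}_t)_t$ and $(\mathbf{ctx}_c)_c$ (e.g., via PCA and pairwise cosine similarities) to look for the low-dimensional spiral, the topic clusters, and the near-orthogonality described in the abstract.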
