Multihead self-attention in cortico-thalamic circuits
Both biological cortico-thalamic networks and artificial transformer networks use canonical computations to perform a wide range of cognitive tasks. In this work, we propose that the structure of cortico-thalamic circuits is well suited to realize a computation analogous to multihead self-attention, the main algorithmic innovation of transformer networks. We assign distinct computational roles to superficial and deep pyramidal cells of the cortex: while superficial pyramidal cells maintain a key-value memory, deep pyramidal cells encode the current query, gain-modulated by the key-value memory in the superficial layer. We show that the structure of this computation matches the fine-grained structure of core and matrix projections from the thalamus to the cortex. We then draw a parallel between a single attention head and a cortical area, and propose that a thalamo-cortico-thalamic pathway implements a computation akin to a multihead, unnormalized, linear self-attention block. Cross-attention corresponds to the key-value memory of one cortical area being used for retrieval by the query of another cortical area. Finally, as a first step towards a mechanistic theory of synaptic learning in cortical transformers, we derive the formal gradients of a typical loss function with respect to the parameters of such a computation.
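To make the central object of the abstract concrete, the following NumPy sketch shows one way to write a multihead, unnormalized, linear self-attention block. It is an illustration under assumed names and shapes, not the circuit model or code from the paper; in particular, the function name, the per-head weight layout, and the full-sequence (non-causal) accumulation of the key-value memory M = KᵀV are all assumptions. Loosely following the abstract's assignment of roles, M plays the part of the key-value memory attributed to superficial pyramidal cells, and the queries Q the part of the deep-layer code that reads out of it.

```python
import numpy as np

def multihead_unnormalized_linear_attention(X, Wq, Wk, Wv, n_heads):
    """Hypothetical sketch of one multihead, unnormalized, linear
    self-attention block: no softmax and no normalization, so the
    output is exactly Q (K^T V) = (Q K^T) V per head.

    X:          (T, d)            sequence of T input vectors
    Wq, Wk, Wv: (n_heads, d, d_h) per-head projection weights
    Returns:    (T, n_heads * d_h) concatenated head outputs
    """
    outputs = []
    for h in range(n_heads):
        Q = X @ Wq[h]  # queries, (T, d_h)
        K = X @ Wk[h]  # keys,    (T, d_h)
        V = X @ Wv[h]  # values,  (T, d_h)
        # Key-value memory: an associative matrix accumulated over all
        # positions (in the abstract's picture, maintained by superficial
        # pyramidal cells of one cortical area, i.e. one head).
        M = K.T @ V    # (d_h, d_h)
        # Retrieval: each query reads out of the memory. Without a
        # softmax the map from inputs to outputs stays linear in V.
        outputs.append(Q @ M)  # (T, d_h)
    return np.concatenate(outputs, axis=-1)

# Toy usage with random weights
rng = np.random.default_rng(0)
T, d, n_heads, d_h = 5, 8, 2, 4
X  = rng.standard_normal((T, d))
Wq = rng.standard_normal((n_heads, d, d_h))
Wk = rng.standard_normal((n_heads, d, d_h))
Wv = rng.standard_normal((n_heads, d, d_h))
print(multihead_unnormalized_linear_attention(X, Wq, Wk, Wv, n_heads).shape)
```

Cross-attention, in this sketch, would amount to computing M from one sequence (one area's key-value memory) and Q from another. Because the block is a composition of linear maps, the gradients the abstract refers to follow from the chain rule; for instance, with per-head output Y = QKᵀV, one gets dL/dQ = (dL/dY) VᵀK.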