[WIP] Tensor Field Networks and e3nn, Simplified


Preface

For the past few weeks, I’ve been extensively using the Tensor Field Network (TFN; Thomas et al., 2018) through the e3nn Python library (Geiger and Smidt, 2022). The TFN features a special message-passing framework called the Tensor Product Convolution that combines different types of features in a complex but principled way. This tensor product had always been a mystery to me, but I think I now understand what’s happening at a high level.

The TFN has been used successfully in many recent models owing to its high “geometric expressiveness” (Joshi et al., 2023). Most works, like DiffDock, EigenFold, and HarmonicFlow, use it in the AI4Science context where local and global atomic interactions matter.

The e3nn report and docs have been helpful, but there’s a lot that’s left as an exercise to users. In this blog post, I hope to visually explain what’s going on under the hood, and how you can use e3nn to build TFNs for your own applications.

I am not going to explain tensors and spherical harmonics in much detail; I assume you know how these parts come together at a high level. The e3nn report is a great introductory writeup about them, so do check it out.


The Real Tensor Product?

If you’ve taken an advanced college-level linear algebra course, you might have learned about the tensor product, $\otimes$. The “tensor product” I talk about in this blog post is a slight misnomer: it is not the true tensor product from college, though it shares properties I’ll describe shortly. From here on out, any mention of “tensor product” refers to the underlying mechanism of the TFN and e3nn-based networks, unless otherwise stated. As someone who had freshly finished advanced linear algebra at school, I found this naming a major source of confusion, and it took me a while to reconcile the two.
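If you’re curious how the two relate, here’s a minimal sketch using e3nn’s o3.FullTensorProduct (a real e3nn class; the example inputs are mine). The college tensor product of two 3D vectors has $3 \times 3 = 9$ components; e3nn takes those same 9 numbers and regroups them into the kinds of pieces we’ll meet below.

```python
# A minimal sketch relating the college tensor product of two 3D vectors
# (a 3 x 3 = 9-dimensional object) to the pieces e3nn works with.
import torch
from e3nn import o3

vec = o3.Irreps("1o")                 # a single vector type
tp = o3.FullTensorProduct(vec, vec)   # plain tensor product of two vectors
print(tp.irreps_out)                  # 1x0e+1x1e+1x2e (dims 1 + 3 + 5 = 9)

x1, x2 = torch.randn(2, 3), torch.randn(2, 3)
out = tp(x1, x2)                      # shape (2, 9): same 9 numbers, regrouped
print(out.shape)
```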

On Tensors

In the context of machine learning, a tensor is seen as a multi-dimensional nested array or container. That’s partly true as a rough analogy, but in terms of differential geometry, a tensor is defined by how it transforms under a change of coordinates. For the TFN, what matters is how each feature transforms under 3D rotations.
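To see this transformation-based view in code, here’s a hedged sketch using e3nn’s Irrep.D_from_angles, which produces the rotation matrix acting on each tensor type. The labels "0e", "1o", and "2e" are e3nn’s names for scalars, vectors, and the next tensor type; more on these below.

```python
# Each tensor type carries its own rule for how it rotates;
# D_from_angles gives that rotation matrix explicitly.
from e3nn import o3

angles = o3.rand_angles()                    # a random rotation (Euler angles)
D0 = o3.Irrep("0e").D_from_angles(*angles)   # 1x1 matrix [[1.]]: scalars never change
D1 = o3.Irrep("1o").D_from_angles(*angles)   # 3x3: the familiar rotation matrix
D2 = o3.Irrep("2e").D_from_angles(*angles)   # 5x5: how the next tensor type rotates
print(D0.shape, D1.shape, D2.shape)
```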

The Tensor Product

In the TFN and e3nn, the tensor product is a special way to combine different types of tensors, ranging from simple scalars (numbers without any directional information) and vectors (quantities with a magnitude and a direction) to higher-order tensors describing different physical characteristics of a system. For instance, an atom can have an elemental identity, the atomic number, which is a scalar. The atom can also have Euclidean coordinates in 3D space, which form a vector.
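In e3nn, such mixed feature types are declared with an Irreps string. Here’s a minimal sketch of the atom example above (the multiplicities and feature choices are mine, for illustration):

```python
# Per-atom features holding one scalar (e.g., the atomic number)
# and one vector (e.g., a 3D position) side by side.
import torch
from e3nn import o3

irreps = o3.Irreps("1x0e + 1x1o")   # one scalar, one vector
print(irreps.dim)                   # 4 = 1 (scalar) + 3 (vector)

num_atoms = 10
feats = torch.randn(num_atoms, irreps.dim)  # each row: [scalar, vx, vy, vz]
```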

How do we refer to these different types of tensors? This brings me to the concept of tensor order, which is exactly the hierarchy we’re looking for. The tensor order $l$ is a non-negative integer that indexes increasing orders of tensorial information: an $l=0$ tensor is a scalar, an $l=1$ tensor is a vector in $\mathbb{R}^3$, and so on. However, an $l=2$ tensor isn’t a matrix. Observe that for a given tensor order $l$, the associated tensors have a dimension of $2l + 1$: scalars have $d=1$ (trivially), vectors have $d=3$ in $\mathbb{R}^3$, whatever comes next has $d=5$, and so on.
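You can check the $2l + 1$ rule directly with e3nn’s Irrep class; here’s a minimal sketch (the parity argument below is an arbitrary choice and doesn’t affect the dimension count):

```python
# Verifying that an order-l tensor has dimension 2l + 1.
from e3nn import o3

for l in range(4):
    ir = o3.Irrep(l, 1)   # parity +1; either parity gives the same dimension
    print(l, ir.dim)      # (0, 1), (1, 3), (2, 5), (3, 7)
```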