<h1>[WIP] Tensor Field Networks and e3nn, Simplified</h1>
<p><em>Rishabh Anand · 2023-11-06 · <a href="https://rish-16.github.io/posts/e3nn-tut">rish-16.github.io/posts/e3nn-tut</a></em></p>
<h2 id="preface">Preface</h2>
<p>For the past few weeks, I’ve been working extensively with the Tensor Field Network (TFN; <a href="https://arxiv.org/abs/1802.08219">Thomas et al., 2018</a>) via the <code class="language-plaintext highlighter-rouge">e3nn</code> Python library (<a href="https://arxiv.org/abs/2207.09453">Geiger and Smidt, 2022</a>). The TFN features a special message passing framework called the Tensor Product Convolution that combines different types of features in a complex but principled way. This tensor product has always been a mystery to me, but I think I now understand what’s happening at a high level.</p>
<p>The TFN has been used successfully in many recent models owing to its high “geometric expressiveness” (<a href="https://arxiv.org/abs/2301.09308">Joshi et al., 2023</a>). Most works, like <a href="https://arxiv.org/abs/2210.01776">DiffDock</a>, <a href="https://arxiv.org/abs/2304.02198">EigenFold</a>, and <a href="https://arxiv.org/abs/2310.05764">HarmonicFlow</a>, use it in the AI4Science context where local and global atomic interactions matter.</p>
<p>The <code class="language-plaintext highlighter-rouge">e3nn</code> <a href="https://arxiv.org/abs/2207.09453">report</a> and <a href="https://docs.e3nn.org/en/latest/index.html">docs</a> have been helpful, but there’s a lot that’s left as an exercise to users. In this blog post, I hope to <strong>visually explain</strong> what’s going on under the hood, and how you can use the <code class="language-plaintext highlighter-rouge">e3nn</code> library to build TFNs for your own applications.</p>
<blockquote>
<p>I am not going to explain tensors and spherical harmonics in much detail. I assume you know how these parts come together at a high level. The <code class="language-plaintext highlighter-rouge">e3nn</code> report is a great introductory writeup about them, so do check that out.</p>
</blockquote>
<hr />
<h2 id="the-real-tensor-product">The <em>Real</em> Tensor Product?</h2>
<p>If you’ve taken an advanced college-level linear algebra course, you might have learned about the <a href="https://en.wikipedia.org/wiki/Tensor_product"><em>tensor product</em></a>, $\otimes$. The “tensor product” I talk about in this blog post is a slight misnomer: it has nothing to do with the true tensor product from college, but it has similar properties that I’ll describe shortly. From here on out, any mention of “tensor product” refers to the underlying mechanism of the TFN and <code class="language-plaintext highlighter-rouge">e3nn</code>-based networks, unless otherwise stated. As someone who had just finished advanced lin-alg at school, this was a major source of confusion and took me a while to reconcile.</p>
<h2 id="on-tensors">On Tensors</h2>
<p>In the context of machine learning, a tensor is seen as a multi-dimensional nested array or container. That’s partly true as a rough analogy, but in terms of differential geometry, a tensor is a multilinear map that transforms in a well-defined way under a change of basis; it is this transformation behaviour, not the container of numbers, that matters here.</p>
<h2 id="the-tensor-product">The Tensor Product</h2>
<p>In the TFN and <code class="language-plaintext highlighter-rouge">e3nn</code>, the tensor product is a special way to combine different types of tensors, ranging from simple scalars (numbers without any directional information), vectors (quantities with a magnitude and direction), to higher-order tensors describing different physical characteristics of a system. For instance, an atom can have an elemental identity, the atomic number, which is a scalar. The atom can also have Euclidean coordinates in 3D space, which is a vector.</p>
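<p>As a rough sketch in plain Python (deliberately <em>not</em> the <code class="language-plaintext highlighter-rouge">e3nn</code> API; the container and names are purely illustrative), here are the two feature types for a single atom, along with the key behaviour that separates them: a rotation leaves the scalar alone but rotates the vector.</p>

```python
import numpy as np

# Illustrative container for one atom's features (hypothetical names).
atom = {
    "atomic_number": np.array([6.0]),        # scalar: dimension 1
    "position": np.array([0.1, -0.3, 0.7]),  # vector in R^3: dimension 3
}

def rotate_z(theta: float) -> np.ndarray:
    """Rotation matrix about the z-axis by angle theta."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

R = rotate_z(np.pi / 2)
rotated_position = R @ atom["position"]   # the vector transforms with R
rotated_scalar = atom["atomic_number"]    # the scalar is unchanged
```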
<p>How do we refer to different types of these tensors? This brings me to the concept of <strong>tensor order</strong>, which is exactly the hierarchy we’re looking for. The tensor order $l$ is a non-negative integer that indexes increasing orders of tensorial information: an $l=0$ tensor is a scalar, an $l=1$ tensor is a vector in $\mathbb{R}^3$, and so on. However, an $l=2$ tensor isn’t a matrix. Observe that for a given tensor order $l$, the associated tensors have a dimension of $2l + 1$: scalars have $d=1$ (trivially), vectors have $d=3$ in $\mathbb{R}^3$, whatever comes next has $d=5$, and so on.</p>
<h1>Expressive GNNs and How To Tame Them</h1>
<p><em>Rishabh Anand · 2022-05-08 · <a href="https://rish-16.github.io/posts/expressive-gnns">rish-16.github.io/posts/expressive-gnns</a></em></p>
<blockquote>
<p>This blog post consists of research I’m currently engaged with at NUS. My findings span the past half-year’s worth of reading the literature and I’m excited to explain these concepts to you from scratch!</p>
</blockquote>
<hr />
<ul class="table-of-content" id="markdown-toc">
<li><a href="#foreword" id="markdown-toc-foreword">Foreword</a></li>
<li><a href="#graph-neural-networks" id="markdown-toc-graph-neural-networks">Graph Neural Networks</a> <ul>
<li><a href="#graphs" id="markdown-toc-graphs">Graphs</a></li>
<li><a href="#graph-neural-networks-1" id="markdown-toc-graph-neural-networks-1">Graph Neural Networks</a></li>
</ul>
</li>
<li><a href="#understanding-expressiveness" id="markdown-toc-understanding-expressiveness">Understanding Expressiveness</a> <ul>
<li><a href="#graph-isomorphism" id="markdown-toc-graph-isomorphism">Graph Isomorphism</a></li>
<li><a href="#weisfeiler-leman-gi-test" id="markdown-toc-weisfeiler-leman-gi-test">Weisfeiler-Leman GI Test</a></li>
<li><a href="#wl-test-and-graph-neural-networks" id="markdown-toc-wl-test-and-graph-neural-networks">WL Test and Graph Neural Networks</a></li>
<li><a href="#higher-order-structures" id="markdown-toc-higher-order-structures">Higher-Order Structures</a></li>
<li><a href="#weisfeiler-leman-hierarchy" id="markdown-toc-weisfeiler-leman-hierarchy">Weisfeiler-Leman Hierarchy</a></li>
</ul>
</li>
<li><a href="#conclusion" id="markdown-toc-conclusion">Conclusion</a></li>
</ul>
<h2 id="foreword">Foreword</h2>
<p>At NUS, I’m currently looking into Graph Neural Network expressiveness and I’m really excited about it. I’ve spoken to a few seniors in this area and have learned so much over the past few months. It’s an exciting line of research within Graph DL and I’m sure the pace will pick up soon. It’s more theoretical than practical currently but there’s definitely room for expansion. In fact, I’m currently looking into the practical aspects of expressiveness at NUS, which seems to be an underexplored niche.</p>
<p>Before we get into the meat of the topic, let’s get some preliminaries out of the way first!</p>
<h2 id="graph-neural-networks">Graph Neural Networks</h2>
<h3 id="graphs">Graphs</h3>
<p>A graph is a data structure that consists of nodes/vertices \(V\) and edges \(E\). These nodes represent entities or objects, and the edges between them denote some relationship. These relationships are either unidirectional or bidirectional. Let’s assume we’re working with undirected graphs for the rest of this post.</p>
<h3 id="graph-neural-networks-1">Graph Neural Networks</h3>
<p>A graph \(\mathcal{G}(V, E)\) consists of vertices (or nodes) \(v \in V\) and edges \(e_{ij} \in E \subseteq {V \times V}\) joining two nodes \(i\) and \(j\). \(e_{ij} = 1\) if there’s a connection between nodes \(i\) and \(j\), 0 otherwise. The neighbourhood of a node \(i\), namely \(\mathcal{N}_i\), is defined as the set of all nodes with outgoing edges to and incoming edges from \(i\); formally, \(\mathcal{N}_i = \{j : e_{ij} \in E\}\).</p>
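<p>As a minimal sketch (plain Python, toy graph), the neighbourhood \(\mathcal{N}_i\) can be read straight off an undirected edge list:</p>

```python
# Toy undirected graph: edges stored as unordered pairs.
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]

def neighbourhood(i, edges):
    """N_i = all nodes sharing an edge with node i."""
    return {j for u, v in edges for j in (u, v) if i in (u, v) and j != i}
```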
<p>Each node \(i\) has an associated representation \(h_i^t \in \mathbb{R}^n\) and (discrete or continuous) label \(y\), for each GNN layer \(t \in \{1, \dots, T\}\). Each node \(i\) starts off with \(h_i^1 = x_i\), where \(x_i \in \mathbb{R}^n\) is the input feature vector for the node. Edges \(e_{ij}\) can also have an associated representation \(a_{ij}^t \in \mathbb{R}^m\) depending on context. Each GNN layer \(t\) performs a single step of <em>Message Passing</em>. This involves combining the target node representation \(h^{t}_i\) with the node representations \(h_j^{t}\) from the neighbourhood \(\mathcal{N}_i\) (Equation \ref{mp}). Intuitively, at a layer \(t\), the GNN looks at the \(t\)-hop neighbourhood of \(i\), represented by a subtree rooted at node \(i\).</p>
\[\begin{equation} \label{mp}
h^{t+1}_i = \sigma(\psi(h^{t}_i,~\square(\{h^t_j : j \in \mathcal{N}_i\})))
\end{equation}\]
<p>Here, \(\psi\) is any affine transformation function (like an MLP) and \(\sigma\) is a non-linear, element-wise activation function (like <em>Sigmoid</em>, <em>ReLU</em>, or <em>Softmax</em>). \(\square\) is a permutation-invariant aggregation function that combines neighbouring node features; choices include “sum”, “max”, “min”, and “mean”. This aggregation function can be thought of as a hashing function that operates on multisets (sets with legal repetition) of node features. The same can be done to edge representation \(a_{ij}^t\) (Equation \ref{edge_mp}):</p>
\[\begin{equation} \label{edge_mp}
a_{ij}^{t+1} = \phi(a_{ij}^t, h^t_i, h^t_j)
\end{equation}\]
<p>\(\phi : \mathbb{R}^m \rightarrow \mathbb{R}^m\), parameterised by \(\theta\), takes in current edge feature \(a_{ij}^t\), and respective node features \(h^t_i\) and \(h^t_j\), to output the new edge representation \(a_{ij}^{t+1}\) for the next layer.</p>
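<p>The node update above can be sketched in a few lines of NumPy. This is an illustrative toy, not any particular GNN: \(\square\) is sum aggregation via the adjacency matrix, \(\psi\) is a single random linear layer on the node and aggregated features, and \(\sigma\) is ReLU.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

n, d = 4, 8                                    # 4 nodes, feature dimension 8
A = np.array([[0, 1, 1, 0],                    # adjacency matrix of a toy graph
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]])
H = rng.normal(size=(n, d))                    # node representations h_i^t

W_self = rng.normal(size=(d, d)) / np.sqrt(d)  # psi acting on the node's own features
W_nbr = rng.normal(size=(d, d)) / np.sqrt(d)   # psi acting on the aggregated messages

def message_passing_step(H):
    agg = A @ H                                # sum over neighbours (permutation-invariant)
    return np.maximum(0.0, H @ W_self + agg @ W_nbr)  # sigma = ReLU

H_next = message_passing_step(H)               # h^{t+1} for all nodes, shape (4, 8)
```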
<hr />
<p><img src="/images/2022-05-08-expressive-gnns/mp.png" width="100%" /></p>
<p><strong>Figure 1:</strong> A node \(i\) has a feature vector \(x_i \in \mathbb{R}^n\) (coloured envelope) and has a neighbourhood \(\mathcal{N}_i\) (left). A single round of Message Passing involves aggregating (collecting) representations from a target node’s neighbourhood and incorporating them into its own representation, for all nodes in the graph in parallel (right).</p>
<hr />
<p><img src="/images/2022-05-08-expressive-gnns/khop.png" width="100%" /></p>
<p><strong>Figure 2:</strong> <b>(a)</b> is the original graph. <b>(b)</b> is the rooted subtree of target node (green) at layer \(t=1\). <b>(c)</b> is the rooted subtree of target node (green) at layer \(t=2\). These rooted subtrees are multisets of node features.</p>
<hr />
<h2 id="understanding-expressiveness">Understanding Expressiveness</h2>
<p>Expressiveness refers to the ability of a GNN to discriminate between two graphs. The inability to learn structural information from graphs results in <em>over-smoothing</em>: two different nodes are assigned the same embedding representation in latent space, and are thereby classified as the same. Therefore, structural awareness is important as it imbues inductive biases such as invariance to positions of nodes into the GNN, thereby allowing it to tell apart graphs. This brings us to the concept of <em>graph isomorphism</em>.</p>
<h3 id="graph-isomorphism">Graph Isomorphism</h3>
<p>Formally, two graphs are isomorphic if there exists a bijection (1:1 mapping) between their vertex sets that preserves edges. This means the connectivities of the two graphs should be alike. If no such bijection exists, the two graphs are non-isomorphic.</p>
<h3 id="weisfeiler-leman-gi-test">Weisfeiler-Leman GI Test</h3>
<p>Two graphs are isomorphic if there exists a bijection between the vertex sets of both graphs. As such, the most notable algorithm for testing graph isomorphism is the <strong>Weisfeiler-Leman</strong> (WL) test. All nodes are assigned an initial <em>colour</em> (node-wise discrete label) and, through iterations of naive vertex refinement, each node’s colour is updated by combining it with the colours of neighbouring nodes. This is done using a hash function that takes in a multiset of neighbouring node colours and outputs a unique label for the next round of refinement. The test declares two graphs non-isomorphic if the distributions of new colours differ at some iteration. To ensure the WL test can really tell apart graphs, we need the hash function to be injective (a unique mapping from multiset to hashed value).</p>
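<p>Here is one possible sketch of WL colour refinement in plain Python. Relabelling each (own colour, sorted multiset of neighbour colours) signature with a fresh integer stands in for the injective hash function; initial colours are node degrees, as in Figure 3 below.</p>

```python
from collections import Counter

def wl_colours(adj, rounds=3):
    """adj: dict mapping node -> list of neighbours. Returns the final colour histogram."""
    colours = {v: len(nbrs) for v, nbrs in adj.items()}   # initial colour = degree
    for _ in range(rounds):
        # Signature = (own colour, sorted multiset of neighbour colours).
        signatures = {v: (colours[v], tuple(sorted(colours[u] for u in adj[v])))
                      for v in adj}
        # Relabel distinct signatures with small integers (an injective mapping).
        palette = {sig: k for k, sig in enumerate(sorted(set(signatures.values())))}
        colours = {v: palette[signatures[v]] for v in adj}
    return Counter(colours.values())

# Two isomorphic triangles get identical colour histograms...
g1 = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
g2 = {"a": ["b", "c"], "b": ["a", "c"], "c": ["a", "b"]}
# ...while a 3-node path does not.
g3 = {0: [1], 1: [0, 2], 2: [1]}
```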
<hr />
<p><img src="/images/2022-05-08-expressive-gnns/hashfunc.png" width="100%" />
<img src="/images/2022-05-08-expressive-gnns/wldemo.jpg" width="100%" /></p>
<p><strong>Figure 3:</strong> The WL test performed on two graphs \(A\) and \(B\) that are isomorphic. Labels are the degrees of each node. The multiset hashing function is \(H(S) = \sum_{i \in S} i^2\). Nodes of the same degree are coloured the same to show that they get mapped to the same hash value. Notice how the distributions of node labels stay the same for the two graphs across all \(n\) iterations; the test therefore cannot tell them apart, which is consistent with the two graphs being isomorphic.</p>
<hr />
<p>However, the WL test is necessary but not sufficient for graph isomorphism: there exist pairs of non-isomorphic graphs that it cannot distinguish. That said, it has been a reliable heuristic so far and works on most graphs.</p>
<hr />
<p><img src="/images/2022-05-08-expressive-gnns/wlfail.png" width="100%" /></p>
<p><strong>Figure 4:</strong> Examples of two graphs indistinguishable by the WL test. They produce identical label distributions through the iterations of colour refinement. It’s catastrophic if datasets have graphs that exhibit such properties and can’t be told apart.</p>
<hr />
<h3 id="wl-test-and-graph-neural-networks">WL Test and Graph Neural Networks</h3>
<p>In fact, we can draw parallels between the WL Test and a GNN. The aggregation function stays the same (simply collect neighbours’ labels) while the multiset hash function becomes the node readout function. Now, instead of node colours, we work with node features in GNNs. GNNs capable of discriminating two nodes (i.e., give them different representations in the embedding space) are <strong>at most as powerful</strong> as the WL Test (upper bound on expressiveness). Moreover, if we make the hash function/aggregation <u>injective</u>, the GNN can be <strong>as powerful as</strong> the WL Test (lower bound on expressiveness). This ensures two nodes are not assigned the same representation in the embedding space, thereby minimising the risk of classifying them as the same.</p>
<hr />
<p><img src="/images/2022-05-08-expressive-gnns/injective.png" width="100%" /></p>
<p><strong>Figure 5:</strong> Mathematically, an injective function ensures that every possible output (in the codomain) has at most one associated input (in the domain) that results in said output. So, by introducing injectivity into the aggregation/readout function, we ensure every node’s post-aggregation multiset is mapped to a unique label for the next iteration. Note that we aren’t including the nodes’ own labels in the respective multisets.</p>
<hr />
<h3 id="higher-order-structures">Higher-Order Structures</h3>
<p>So far, we’ve seen the WL test being used to discriminate between single nodes based on their colours/labels. There are more complex structures <em>within</em> graphs that can be used to tell apart said graphs. Examples of these higher-order structures include rooted subtrees, \(k\)-hop neighbourhoods, and pairs/tuples of connected nodes. The more expressive a GNN, the better it can make use of these structural hints (i.e., these higher-order structures) to discriminate graphs during training.</p>
<blockquote>
<p>In fact, certain works in the literature even augment GNNs with this higher-order structural information that cannot directly be inferred through the simple WL test.</p>
</blockquote>
<h3 id="weisfeiler-leman-hierarchy">Weisfeiler-Leman Hierarchy</h3>
<p>The vanilla WL test examines individual nodes and looks at their immediate 1-hop neighbourhood. GNNs capable of discerning graphs using this 1-hop neighbourhood are called 1-WL GNNs. More formally, we say the GNN is as <em>powerful</em> as 1-WL. We can generalise this to the \(k\)-hop neighbourhood where \(k \in \{2, 3, \dots\}\). This wider neighbourhood can be viewed as a larger multiset of neighbours and <em>their</em> neighbours, forming so-called higher-order structures. When a GNN is able to discern these higher-order structures, we call it a \(k\)-WL GNN, and claim the GNN is as powerful as \(k\)-WL. Expressiveness is measured using these different “levels” of \(k\)-WL, altogether forming the <strong>WL Hierarchy</strong>. A \(k\)-WL GNN is strictly weaker than a \((k+1)\)-WL GNN: there exists a pair of graphs that the latter can discriminate while the former cannot, but the converse never holds.</p>
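<p>The \(k\)-hop neighbourhood underpinning these levels can be sketched with a simple breadth-first expansion (toy graph, illustrative only):</p>

```python
def k_hop_neighbourhood(adj, i, k):
    """All nodes reachable from node i within k hops (excluding i itself)."""
    frontier, seen = {i}, {i}
    for _ in range(k):
        frontier = {u for v in frontier for u in adj[v]} - seen  # expand one hop
        seen |= frontier
    return seen - {i}

# Path graph 0 - 1 - 2 - 3.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
```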
<hr />
<p><img src="/images/2022-05-08-expressive-gnns/1wl.png" width="65%" /></p>
<p><strong>Figure 6.1:</strong> Expressiveness is quantitatively defined using the WL Hierarchy. <strong>(top left)</strong> is the original graph. <strong>(top right)</strong> shows 1-WL expressiveness using the immediate 1-hop neighbourhood. The gray rectangles are the aggregated messages from the immediate neighbours. This is rather trivial.</p>
<p><img src="/images/2022-05-08-expressive-gnns/2wl.png" width="65%" /></p>
<p><strong>Figure 6.2:</strong> Here, I show 2-WL expressiveness using the 2-hop neighbourhood.</p>
<p><img src="/images/2022-05-08-expressive-gnns/3wl.png" width="65%" /></p>
<p><strong>Figure 6.3:</strong> Likewise, here, I show 3-WL expressiveness using the 3-hop neighbourhood.</p>
<p>The objective is to discriminate both graphs (hence the \(\neq\)) despite their structures being the same. Altogether, they form rooted subtrees w.r.t. the target nodes being compared.</p>
<hr />
<p>In fact, regular Message Passing Neural Networks can fall short of 1-WL because aggregation functions like “mean” and “max” cannot tell apart certain non-identical multisets. To avoid such scenarios, we introduce injectivity into the aggregation (multiset hashing) function; the “sum” aggregator is one such example.</p>
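<p>A quick check on two toy multisets of scalar neighbour features shows why: “max” and “mean” collapse them to the same value, while “sum” keeps them apart, at least on this example.</p>

```python
# Two different multisets of neighbour features.
ms1 = [2.0, 2.0]   # two neighbours, both with feature 2
ms2 = [2.0]        # a single neighbour with feature 2

assert max(ms1) == max(ms2)                        # "max" cannot tell them apart
assert sum(ms1) / len(ms1) == sum(ms2) / len(ms2)  # neither can "mean"
assert sum(ms1) != sum(ms2)                        # "sum" distinguishes them
```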
<hr />
<p><img src="/images/2022-05-08-expressive-gnns/aggrfail.png" width="100%" /></p>
<p><strong>Figure 7:</strong> The graphs on the left cannot be discriminated using the “max” aggregator. The graphs on the right cannot be discriminated using the “max” and “mean” aggregators. This is because these functions are not injective by nature.</p>
<h2 id="conclusion">Conclusion</h2>
<p>The theoretical research community has more or less moved away from standard GNNs towards expressive GNNs like those mentioned above. The obvious benefits include better structural awareness, which is paramount for real-life problems like protein studies, molecule interaction modelling, and social media analysis. I hope this blogpost shares some exciting insights about this new family of expressive GNNs. In terms of what’s to come, I believe we need more <em>benchmarking efforts</em> to really compare these expressive GNNs with one another. This means coming up with new, dedicated datasets, both real-life and synthetic.</p>
<p>Let me know if you want an in-depth review of notable works from the literature; there have been a lot of exciting SOTA works coming out lately. Also let me know if you want to access the extensive reading list I used to get me up-to-speed.</p>
<blockquote>
<p>I look forward to sharing more with you in time! Lots of exciting work to be done and lots of learnings and takeaways in store. Stay tuned!</p>
<p>Till then, I’ll see you in the next post :D</p>
</blockquote>
<h1>Recent Advances in Deep Learning for Routing Problems</h1>
<p><em>Rishabh Anand · 2022-03-22 · <a href="https://rish-16.github.io/posts/routing-dl">rish-16.github.io/posts/routing-dl</a></em></p>
<blockquote>
<p>This blog post was written alongside Chaitanya Joshi, a good friend, senior-mentor, and current PhD student at Cambridge University. We are happy to announce this blog post was accepted (Top 50%) into the ICLR Blog Post Track 2022!!!</p>
</blockquote>
<hr />
<p><strong>TL;DR</strong> Developing neural network-driven solvers for combinatorial optimization problems such as the Travelling Salesperson Problem has seen a surge of academic interest recently. This blogpost presents a <strong>Neural Combinatorial Optimization</strong> pipeline that unifies several recently proposed model architectures and learning paradigms into one single framework. Through the lens of the pipeline, we analyze recent advances in deep learning for routing problems and provide new directions to stimulate future research towards practical impact.</p>
<ul class="table-of-content" id="markdown-toc">
<li><a href="#background-on-combinatorial-optimization-problems" id="markdown-toc-background-on-combinatorial-optimization-problems">Background on Combinatorial Optimization Problems</a> <ul>
<li><a href="#tsp-and-routing-problems" id="markdown-toc-tsp-and-routing-problems">TSP and Routing Problems</a></li>
<li><a href="#deep-learning-to-solve-routing-problems" id="markdown-toc-deep-learning-to-solve-routing-problems">Deep Learning to solve Routing Problems</a></li>
<li><a href="#neural-combinatorial-optimization" id="markdown-toc-neural-combinatorial-optimization">Neural Combinatorial Optimization</a></li>
</ul>
</li>
<li><a href="#unified-neural-combinatorial-optimization-pipeline" id="markdown-toc-unified-neural-combinatorial-optimization-pipeline">Unified Neural Combinatorial Optimization Pipeline</a> <ul>
<li><a href="#1-defining-the-problem-via-graphs" id="markdown-toc-1-defining-the-problem-via-graphs">(1) Defining the problem via graphs</a></li>
<li><a href="#2-obtaining-latent-embeddings-for-graph-nodes-and-edges" id="markdown-toc-2-obtaining-latent-embeddings-for-graph-nodes-and-edges">(2) Obtaining latent embeddings for graph nodes and edges</a></li>
<li><a href="#3--4-converting-embeddings-into-discrete-solutions" id="markdown-toc-3--4-converting-embeddings-into-discrete-solutions">(3 + 4) Converting embeddings into discrete solutions</a></li>
<li><a href="#5-training-the-model" id="markdown-toc-5-training-the-model">(5) Training the model</a></li>
</ul>
</li>
<li><a href="#characterizing-prominent-papers-via-the-pipeline" id="markdown-toc-characterizing-prominent-papers-via-the-pipeline">Characterizing Prominent Papers via the Pipeline</a></li>
<li><a href="#recent-advances-and-avenues-for-future-work" id="markdown-toc-recent-advances-and-avenues-for-future-work">Recent Advances and Avenues for Future Work</a> <ul>
<li><a href="#leveraging-equivariance-and-symmetries" id="markdown-toc-leveraging-equivariance-and-symmetries">Leveraging Equivariance and Symmetries</a></li>
<li><a href="#improved-graph-search-algorithms" id="markdown-toc-improved-graph-search-algorithms">Improved Graph Search Algorithms</a></li>
<li><a href="#learning-to-improve-sub-optimal-solutions" id="markdown-toc-learning-to-improve-sub-optimal-solutions">Learning to Improve Sub-optimal Solutions</a></li>
<li><a href="#learning-paradigms-that-promote-generalization" id="markdown-toc-learning-paradigms-that-promote-generalization">Learning Paradigms that Promote Generalization</a></li>
<li><a href="#improved-evaluation-protocols" id="markdown-toc-improved-evaluation-protocols">Improved Evaluation Protocols</a></li>
</ul>
</li>
<li><a href="#summary" id="markdown-toc-summary">Summary</a></li>
</ul>
<hr />
<h2 id="background-on-combinatorial-optimization-problems">Background on Combinatorial Optimization Problems</h2>
<p><strong>Combinatorial Optimization</strong> is a practical field at the intersection of mathematics and computer science that aims to solve constrained optimization problems which are NP-Hard. <strong>NP-Hard problems</strong> are challenging as exhaustively searching for their solutions is beyond the limits of modern computers. It is impossible to solve NP-Hard problems optimally at large scales.</p>
<p><strong>Why should we care?</strong> Because robust and reliable approximation algorithms to popular problems have immense practical applications and are the backbone of modern industries. For example, the <strong>Travelling Salesperson Problem</strong> (TSP) is the most popular of all Combinatorial Optimization Problems (COPs) and comes up in applications ranging from logistics and scheduling to genomics and systems biology.</p>
<blockquote>
<p>The Travelling Salesperson Problem is so famous, or notorious, that it even has an <a href="https://xkcd.com/399/">xkcd comic</a> dedicated to it!</p>
</blockquote>
<h3 id="tsp-and-routing-problems">TSP and Routing Problems</h3>
<p>TSP is also a classic example of a <strong>Routing Problem</strong> – Routing Problems are a class of COPs that require a sequence of nodes (e.g., cities) or edges (e.g., roads between cities) to be traversed in a specific order while fulfilling a set of constraints or optimising a set of variables. TSP requires a set of edges to be traversed in an order that ensures all nodes are visited exactly once. In the algorithmic sense, the optimal “tour” for our salesperson is a sequence of selected edges that provides the minimal distance or time taken over a Hamiltonian cycle; see Figure 1 for an illustration.</p>
<figure><center>
<img src="/images/routing-dl/tsp-gif.gif" width="60%" />
<figcaption><b>Figure 1:</b> TSP asks the following question: Given a list of cities and the distances between each pair of cities, what is the <b>shortest possible route</b> that a salesperson can take to <b>visit each city</b> and <b>returns to the origin city</b>?
(Source: <a href="http://mathgifs.blogspot.com/2014/03/the-traveling-salesman.html">MathGifs</a>)</figcaption>
</center></figure>
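<p>For tiny instances, the tour definition above can be made concrete with brute-force enumeration over Hamiltonian cycles. This is exponential in the number of cities, which is exactly why it’s only an illustration:</p>

```python
import itertools
import math

cities = [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 0.0)]  # corners of a unit square

def tour_length(order):
    """Total length of the Hamiltonian cycle visiting cities in this order."""
    return sum(math.dist(cities[order[k]], cities[order[(k + 1) % len(order)]])
               for k in range(len(order)))

def brute_force_tsp(n):
    # Fix city 0 as the start; try every ordering of the remaining cities.
    best = min(itertools.permutations(range(1, n)),
               key=lambda p: tour_length((0,) + p))
    return (0,) + best, tour_length((0,) + best)

tour, length = brute_force_tsp(len(cities))   # optimal tour around the square: length 4.0
```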
<p>In real-world and practical scenarios, Routing Problems, or Vehicle Routing Problems (VRPs), can involve challenging constraints beyond the somewhat <em>vanilla</em> TSP; they are generalisations of TSP. For example, the <strong>TSP with Time Windows</strong> (TSPTW) adds a “time window” constraint to nodes in a TSP graph. This means certain nodes can either be active or inactive at a given time, i.e., they can only be visited during certain intervals. Another variant, the <strong>Capacitated Vehicle Routing Problem</strong> (CVRP) aims to find the optimal routes for a fleet of vehicles (i.e., multiple salespersons) visiting a set of customers (i.e., cities), with each vehicle having a maximum carrying capacity.</p>
<figure><center>
<img src="/images/routing-dl/vrps.png" width="50%" />
<figcaption><b>Figure 2:</b> TSP and the associated class of Vehicle Routing Problems. VRPs can be characterized by their constraints, and this figure presents the relatively well-studied ones. There could be VRPs in the wild with <b>more complex</b> and <b>non-standard constraints</b>! (Source: adapted from <a href="https://ieeexplore.ieee.org/abstract/document/6887420">Benslimane and Benadada, 2014</a>)</figcaption>
</center></figure>
<h3 id="deep-learning-to-solve-routing-problems">Deep Learning to solve Routing Problems</h3>
<p>Developing reliable algorithms and solvers for routing problems such as VRPs requires significant <strong>expert intuition</strong> and years of <strong>trial-and-error</strong>. For example, the state-of-the-art TSP solver, <strong>Concorde</strong>, leverages over 50 years of research on linear programming, cutting plane algorithms and branch-and-bound; here is an <a href="https://www.youtube.com/watch?v=q8nQTNvCrjE">inspiring video</a> on its history. Concorde can find optimal solutions up to tens of thousands of nodes, but with extremely long execution time. As you can imagine, designing algorithms for complex VRPs is even more challenging and time-consuming, especially with real-world constraints such as capacities or time windows in the mix.</p>
<p>This has led the machine learning community to ask the following question:</p>
<p><strong>Can we use deep learning to automate and augment expert intuition required for solving COPs?</strong></p>
<blockquote>
<p>See this masterful survey from Mila for more in-depth motivation: [<a href="https://arxiv.org/abs/1811.06128">Bengio et al., 2020</a>].</p>
</blockquote>
<h3 id="neural-combinatorial-optimization">Neural Combinatorial Optimization</h3>
<p><a href="https://www.chaitjo.com/post/neural-combinatorial-optimization/">Neural Combinatorial Optimization</a> is an attempt to use <strong>deep learning as a hammer</strong> to hit the <strong>COP nails</strong>. Neural networks are trained to produce approximate solutions to COPs by directly learning from problem instances themselves. This line of research started at Google Brain with the seminal <a href="https://arxiv.org/abs/1506.03134">Seq2seq Pointer Networks</a> and <a href="https://arxiv.org/abs/1611.09940">Neural Combinatorial Optimization with RL</a> papers. Today, <a href="https://arxiv.org/abs/2102.09544">Graph Neural Networks</a> are usually the architecture of choice at the core of deep learning-driven solvers as they tackle the graph structure of these problems.</p>
<p>Neural Combinatorial Optimization aims to improve over traditional COP solvers in the following ways:</p>
<ul>
<li>
<p><strong>No handcrafted heuristics.</strong> Instead of application experts manually designing heuristics and rules, neural networks learn them via imitating an optimal solver or via reinforcement learning (we describe a pipeline for this in the next section).</p>
</li>
<li>
<p><strong>Fast inference on GPUs.</strong> Traditional solvers can often have prohibitive execution time for large-scale problems, e.g., Concorde took 7.5 months to solve the largest TSP with 109,399 nodes. On the other hand, once a neural network has been trained to approximately solve a COP, they have significantly favorable time complexity and can be parallelized via GPUs. This makes them highly desirable for real-time decision-making problems, especially routing problems.</p>
</li>
<li>
<p><strong>Tackling novel and under-studied COPs.</strong> The development of problem-specific COP solvers for novel or understudied problems that have esoteric constraints can be significantly sped up via neural combinatorial optimization. Such problems often arise in scientific discovery or computer architecture, e.g., an exciting success story is <a href="https://www.nature.com/articles/s41586-021-03544-w">Google’s chip design system</a> that will power the next generation of TPUs. You read that right – <strong>the next TPU chip for running neural networks has been designed by a neural network!</strong></p>
</li>
</ul>
<hr />
<h2 id="unified-neural-combinatorial-optimization-pipeline">Unified Neural Combinatorial Optimization Pipeline</h2>
<p>Using TSP as a canonical example, we now present a generic <strong>neural combinatorial optimization pipeline</strong> that can be used to characterize modern deep learning-driven approaches to several routing problems.</p>
<p>State-of-the-art approaches for TSP take the raw coordinates of cities as input and leverage <strong>GNNs</strong> or <strong>Transformers</strong> combined with classical <strong>graph search</strong> algorithms to constructively build approximate solutions. Architectures can be broadly classified as: (1) <strong>autoregressive</strong> approaches, which build solutions in a step-by-step fashion; and (2) <strong>non-autoregressive</strong> models, which produce the solution in one shot. Models can be trained to <strong>imitate optimal solvers</strong> via supervised learning or by minimizing the length of TSP tours via <strong>reinforcement learning</strong>.</p>
<figure><center>
<img src="/images/routing-dl/pipeline-box.png" width="75%" />
<figcaption><b>Figure 3:</b> Neural combinatorial optimization pipeline (Source: <a href="https://arxiv.org/abs/2006.07054">Joshi et al., 2021</a>).</figcaption>
</center></figure>
<p>The 5-stage pipeline from <a href="https://arxiv.org/abs/2006.07054">Joshi et al., 2021</a> brings together prominent model architectures and learning paradigms into <strong>one unified framework</strong>. This will enable us to dissect and analyze recent developments in deep learning for routing problems, and provide new directions to stimulate future research.</p>
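<p>As a hedged sketch of the solution-decoding stages of such a pipeline (the actual decoders vary from paper to paper), greedy search over an edge-probability “heatmap” constructs a tour one city at a time. The heatmap here is random; in a trained model it would come from the network.</p>

```python
import numpy as np

rng = np.random.default_rng(42)
n = 5
heatmap = rng.random((n, n))   # stand-in for model-predicted edge probabilities

def greedy_decode(heatmap):
    """Build a tour by repeatedly moving to the highest-scoring unvisited city."""
    n = heatmap.shape[0]
    tour, visited = [0], {0}
    while len(tour) < n:
        scores = heatmap[tour[-1]].copy()
        scores[list(visited)] = -np.inf   # mask already-visited cities
        nxt = int(np.argmax(scores))
        tour.append(nxt)
        visited.add(nxt)
    return tour

tour = greedy_decode(heatmap)   # a valid permutation of the cities, starting at city 0
```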
<h3 id="1-defining-the-problem-via-graphs">(1) Defining the problem via graphs</h3>
<figure><center>
<img src="/images/routing-dl/pipeline-1.png" width="60%" />
<figcaption><b>Figure 4: Problem Definition:</b> TSP is formulated via a fully-connected graph of cities/nodes, which can be sparsified further.</figcaption>
</center></figure>
<p>TSP is formulated via a fully-connected graph where <strong>nodes</strong> correspond to <strong>cities</strong> and <strong>edges</strong> denote <strong>roads</strong> between them. The graph can be sparsified via heuristics such as k-nearest neighbors. This enables models to scale up to large instances where pairwise computation for all nodes is intractable [<a href="https://arxiv.org/abs/1704.01665">Khalil et al., 2017</a>] or learn faster by reducing the search space [<a href="https://arxiv.org/abs/1906.01227">Joshi et al., 2019</a>].</p>
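<p>To make the sparsification step concrete, here is a minimal sketch (our own illustration with a hypothetical function name, not code from any of the cited papers) of building a k-nearest-neighbor adjacency matrix from raw city coordinates:</p>

```python
import numpy as np

def knn_adjacency(coords: np.ndarray, k: int) -> np.ndarray:
    """Sparsify a fully-connected city graph: keep an edge i -> j
    only if j is among the k nearest neighbors of i (by Euclidean distance)."""
    n = coords.shape[0]
    # Pairwise Euclidean distance matrix, shape (n, n).
    dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)  # exclude self-loops
    adj = np.zeros((n, n), dtype=bool)
    # argsort along each row; the first k columns are the k nearest neighbors.
    nearest = np.argsort(dists, axis=1)[:, :k]
    rows = np.repeat(np.arange(n), k)
    adj[rows, nearest.ravel()] = True
    return adj

coords = np.random.default_rng(0).random((50, 2))  # 50 cities in the unit square
adj = knn_adjacency(coords, k=10)
```

<p>Each city keeps edges only to its <code>k</code> closest cities, shrinking the edge count from O(n²) to O(nk).</p>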
<h3 id="2-obtaining-latent-embeddings-for-graph-nodes-and-edges">(2) Obtaining latent embeddings for graph nodes and edges</h3>
<figure><center>
<img src="/images/routing-dl/pipeline-2.png" width="60%" />
<figcaption><b>Figure 5: Graph Embedding:</b> Embeddings for each graph node are obtained using a <b>Graph Neural Network</b> encoder, which builds local structural features via recursively aggregating features from each node's neighbors.</figcaption>
</center></figure>
<p>A GNN or Transformer encoder computes <strong>hidden representations</strong> or embeddings for each node and/or edge in the input TSP graph. At each layer, nodes gather features from their neighbors to represent <strong>local graph structure</strong> via recursive message passing. Stacking $L$ layers allows the network to build representations from the $L$-hop neighborhood of each node.</p>
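<p>As a rough sketch of what one such layer computes, here is a simplified mean-aggregation message-passing layer in NumPy (our own toy example; the actual encoders in these papers add gating, attention, and normalization on top of this):</p>

```python
import numpy as np

def message_passing_layer(h, adj, W):
    """One simplified message-passing layer: each node averages its
    neighbors' features, applies a linear map W, and a ReLU.
    h: (n, d) node features; adj: (n, n) boolean adjacency; W: (d, d)."""
    deg = adj.sum(axis=1, keepdims=True)            # neighbor counts, shape (n, 1)
    agg = (adj.astype(float) @ h) / np.maximum(deg, 1)  # mean over neighbors
    return np.maximum(agg @ W, 0.0)                 # linear transform + ReLU

rng = np.random.default_rng(0)
n, d, L = 20, 8, 3
h = rng.normal(size=(n, d))
adj = rng.random((n, n)) < 0.3
Ws = [rng.normal(size=(d, d)) * 0.1 for _ in range(L)]
for W in Ws:  # stacking L layers gives each node an L-hop receptive field
    h = message_passing_layer(h, adj, W)
```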
<p><strong>Anisotropic</strong> and <strong>attention-based GNNs</strong> such as Transformers [<a href="https://hanalog.polymtl.ca/wp-content/uploads/2018/11/cpaior-learning-heuristics-6.pdf">Deudon et al., 2018</a>, <a href="https://arxiv.org/abs/1803.08475">Kool et al., 2019</a>] and Gated Graph ConvNets [<a href="https://arxiv.org/abs/1906.01227">Joshi et al., 2019</a>] have emerged as the default choice for encoding routing problems. The attention mechanism during neighborhood aggregation is critical as it allows each node to weigh its neighbors based on their <strong>relative importance</strong> for solving the task at hand.</p>
<blockquote>
<p>Importantly, the Transformer encoder can be seen as an attentional GNN, i.e., <a href="https://petar-v.com/GAT/">Graph Attention Network (GAT)</a>, on a fully-connected graph. See <a href="https://thegradient.pub/transformers-are-graph-neural-networks/">this blogpost</a> for an intuitive explanation.</p>
</blockquote>
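<p>The anisotropic idea can be sketched as a GAT-style layer in which learned attention weights replace a uniform mean over neighbors (again a toy NumPy illustration of our own, not the exact formulation from the papers above):</p>

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_aggregate(h, adj, a):
    """GAT-flavored anisotropic aggregation: score each neighbor pair,
    softmax-normalize over neighbors, then take the weighted sum.
    h: (n, d) features; adj: (n, n) bool; a: (d, d) learnable scorer."""
    scores = (h @ a) @ h.T                # (n, n) pairwise compatibility
    scores = np.where(adj, scores, -1e9)  # non-neighbors get ~zero weight
    alpha = softmax(scores, axis=1)       # rows sum to 1 over each node's neighbors
    return alpha @ h

rng = np.random.default_rng(0)
n, d = 12, 4
h = rng.normal(size=(n, d))
adj = np.zeros((n, n), dtype=bool)
idx = np.arange(n)
adj[idx, (idx + 1) % n] = adj[idx, (idx - 1) % n] = True  # ring graph
out = attention_aggregate(h, adj, rng.normal(size=(d, d)))
```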
<h3 id="3--4-converting-embeddings-into-discrete-solutions">(3 + 4) Converting embeddings into discrete solutions</h3>
<figure><center>
<img src="/images/routing-dl/pipeline-3.png" width="70%" />
<figcaption><b>Figure 5: Solution Decoding and Search:</b> Probabilities are assigned to each node or edge for <b>belonging to the solution set</b> (here, an MLP makes a prediction per edge to obtain a 'heatmap' of edge probabilities), and then converted into <b>discrete decisions</b> through classical graph search techniques such as greedy search or beam search.</figcaption>
</center></figure>
<p>Once the nodes and edges of the graph have been encoded into latent representations, we must decode them into discrete TSP solutions.
This is done via a two-step process: first, probabilities are assigned to each node or edge for belonging to the solution set, either independently of one another (i.e., <strong>Non-autoregressive decoding</strong>) or conditionally through graph traversal (i.e., <strong>Autoregressive decoding</strong>). Next, the predicted probabilities are converted into discrete decisions through classical <strong>graph search techniques</strong> such as greedy search or beam search guided by the probabilistic predictions (more on graph search later, when we discuss recent trends and future directions).</p>
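<p>For instance, greedy search over a non-autoregressive edge heatmap can be sketched as follows (a toy illustration of our own, under simplifying assumptions):</p>

```python
import numpy as np

def greedy_tour_from_heatmap(heat: np.ndarray, start: int = 0) -> list:
    """Greedily decode a TSP tour from an (n, n) edge-probability heatmap:
    from the current city, always move to the unvisited city with the
    highest predicted edge probability."""
    n = heat.shape[0]
    tour, visited = [start], {start}
    while len(tour) < n:
        cur = tour[-1]
        probs = heat[cur].copy()
        probs[list(visited)] = -np.inf  # mask already-visited cities
        nxt = int(np.argmax(probs))
        tour.append(nxt)
        visited.add(nxt)
    return tour

heat = np.random.default_rng(0).random((6, 6))  # stand-in for model predictions
tour = greedy_tour_from_heatmap(heat)
```

<p>Beam search follows the same recipe but keeps the top-b partial tours at each step instead of only the single best one.</p>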
<p>The choice of decoder comes with tradeoffs between <strong>data-efficiency</strong> and <strong>efficiency of implementation</strong>:
Autoregressive decoders [<a href="https://arxiv.org/abs/1803.08475">Kool et al., 2019</a>] cast TSP as a Seq2Seq or <strong>language translation task</strong> from a set of unordered cities to an ordered tour. They explicitly model the <strong>sequential inductive bias</strong> of routing problems through step-by-step selection of one node at a time. On the other hand, Non-autoregressive decoders [<a href="https://arxiv.org/abs/1906.01227">Joshi et al., 2019</a>] cast TSP as the task of producing <strong>edge probability heatmaps</strong>. The NAR approach is significantly faster and better suited for real-time inference as it produces predictions in <strong>one shot</strong> instead of step-by-step. However, it ignores the sequential nature of TSP, and may be less efficient to train when compared fairly to AR decoding [<a href="https://arxiv.org/abs/2006.07054">Joshi et al., 2021</a>].</p>
<h3 id="5-training-the-model">(5) Training the model</h3>
<p>Finally, the entire encoder-decoder model is trained in an <strong>end-to-end</strong> fashion, exactly like deep learning models for computer vision or natural language processing. In the simplest case, models can be trained to produce close-to-optimal solutions via <strong>imitating an optimal solver</strong>, i.e., via supervised learning. For TSP, the <strong>Concorde</strong> solver is used to generate labelled training datasets of optimal tours for millions of random instances. Models with AR decoders are trained via teacher-forcing to output the optimal sequence of tour nodes [<a href="https://arxiv.org/abs/1506.03134">Vinyals et al., 2015</a>], while those with NAR decoders are trained to identify edges traversed during the tour from non-traversed edges [<a href="https://arxiv.org/abs/1906.01227">Joshi et al., 2019</a>].</p>
<p>However, creating labelled datasets for supervised learning is an <strong>expensive</strong> and <strong>time-consuming process</strong>. Especially for very large problem instances, the exactness guarantees of optimal solvers may no longer materialise, leading to inexact solutions being used for supervised training. This is far from ideal from both practical and theoretical standpoints [<a href="https://arxiv.org/abs/2002.09398">Yehuda et al., 2020</a>].</p>
<p><strong>Reinforcement learning</strong> is an elegant alternative in the absence of groundtruth solutions, as is often the case for understudied problems. As routing problems generally require sequential decision making to <strong>minimize a problem-specific cost function</strong> (e.g., the tour length for TSP), they can naturally be cast in the RL framework, which trains an agent to <strong>maximize a reward</strong> (the negative of the cost function). Models with AR decoders can be trained via standard policy gradient algorithms [<a href="https://arxiv.org/abs/1803.08475">Kool et al., 2019</a>] or Q-Learning [<a href="https://arxiv.org/abs/1704.01665">Khalil et al., 2017</a>].</p>
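<p>The reward itself is simple to compute: the negative length of the closed tour. A minimal sketch:</p>

```python
import numpy as np

def tour_length(coords: np.ndarray, tour: list) -> float:
    """Total Euclidean length of a closed tour (returns to the start city)."""
    order = coords[tour + [tour[0]]]  # append the start city to close the loop
    return float(np.linalg.norm(np.diff(order, axis=0), axis=1).sum())

# Four cities on a unit square; visiting them in order gives a tour of length 4.
coords = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]])
reward = -tour_length(coords, [0, 1, 2, 3])  # RL reward = negative cost
```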
<hr />
<h2 id="characterizing-prominent-papers-via-the-pipeline">Characterizing Prominent Papers via the Pipeline</h2>
<p>We can characterize prominent works in deep learning for TSP through the 5-stage pipeline. Recall that the pipeline consists of: (1) Problem Definition → (2) Graph Embedding → (3) Solution Decoding → (4) Solution Search → (5) Policy Learning. Starting from the Pointer Networks paper by Oriol Vinyals and collaborators, the following <strong>table</strong> highlights in <span style="color:red">Red</span> the major innovations and contributions for several notable and early papers.</p>
<table>
<thead>
<tr>
<th>Paper</th>
<th>Definition</th>
<th>Graph Embedding</th>
<th>Solution Decoding</th>
<th>Solution Search</th>
<th>Policy Learning</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://arxiv.org/abs/1506.03134">Vinyals et al., 2015</a></td>
<td>Sequence</td>
<td><span style="color:red">Seq2Seq</span></td>
<td><span style="color:red">Attention (AR)</span></td>
<td>Beam Search</td>
<td>Imitation (SL)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1611.09940">Bello et al., 2017</a></td>
<td>Sequence</td>
<td>Seq2Seq</td>
<td>Attention (AR)</td>
<td>Sampling</td>
<td><span style="color:red">Actor-critic (RL)</span></td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1704.01665">Khalil et al., 2017</a></td>
<td><span style="color:red">Sparse Graph</span></td>
<td><span style="color:red">Structure2vec</span></td>
<td>MLP (AR)</td>
<td>Greedy Search</td>
<td><span style="color:red">DQN (RL)</span></td>
</tr>
<tr>
<td><a href="https://hanalog.polymtl.ca/wp-content/uploads/2018/11/cpaior-learning-heuristics-6.pdf">Deudon et al., 2018</a></td>
<td>Full Graph</td>
<td><span style="color:red">Transformer Encoder</span></td>
<td>Attention (AR)</td>
<td>Sampling + <span style="color:red">Local Search</span></td>
<td>Actor-critic (RL)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1803.08475">Kool et al., 2019</a></td>
<td>Full Graph</td>
<td><span style="color:red">Transformer Encoder</span></td>
<td>Attention (AR)</td>
<td>Sampling</td>
<td><span style="color:red">Rollout (RL)</span></td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1906.01227">Joshi et al., 2019</a></td>
<td>Sparse Graph</td>
<td><span style="color:red">Residual Gated GCN</span></td>
<td><span style="color:red">MLP Heatmap (NAR)</span></td>
<td>Beam Search</td>
<td>Imitation (SL)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1911.04936">Ma et al., 2020</a></td>
<td>Full Graph</td>
<td>GCN</td>
<td><span style="color:red">RNN + Attention (AR)</span></td>
<td>Sampling</td>
<td>Rollout (RL)</td>
</tr>
</tbody>
</table>
<hr />
<h2 id="recent-advances-and-avenues-for-future-work">Recent Advances and Avenues for Future Work</h2>
<p>With the unified 5-stage pipeline in place, let us highlight some <strong>recent advances</strong> and <strong>trends</strong> in deep learning for routing problems. We will also provide some future research directions with a focus on improving generalization to large-scale and real-world instances.</p>
<h3 id="leveraging-equivariance-and-symmetries">Leveraging Equivariance and Symmetries</h3>
<p>One of the most influential early works, the autoregressive Attention Model [<a href="https://arxiv.org/abs/1803.08475">Kool et al., 2019</a>], considers TSP as a Seq2Seq language translation problem and sequentially constructs TSP tours as permutations of cities. One immediate drawback of this formulation is that it does not consider the <strong>underlying symmetries of routing problems</strong>.</p>
<figure><center>
<img src="/images/routing-dl/pomo.png" width="75%" />
<figcaption><b>Figure 6:</b> In general, a TSP has one unique optimal solution (L). However, under the autoregressive formulation when a solution is represented as a sequence of nodes, <b>multiple optimal permutations</b> exist (R). (Source: <a href="https://arxiv.org/abs/2010.16011">Kwon et al., 2020</a>)</figcaption>
</center></figure>
<p><strong>POMO: Policy Optimization with Multiple Optima</strong> [<a href="https://arxiv.org/abs/2010.16011">Kwon et al., 2020</a>] proposes to leverage invariance to the starting city in the constructive autoregressive formulation. They train the same Attention Model, but with a new reinforcement learning algorithm (step 5 in the pipeline) which exploits the existence of multiple optimal tour permutations.</p>
<figure><center>
<img src="/images/routing-dl/equivariance.png" width="75%" />
<figcaption><b>Figure 7:</b> TSP solutions remain unchanged under the <b>Euclidean symmetry group</b> of rotations, reflections, and translations to the city coordinates. Incorporating these symmetries into the model may be a principled approach to tackling large-scale TSPs.</figcaption>
</center></figure>
<p>Similarly, a very recent upgrade of the Attention Model by <a href="https://arxiv.org/abs/2110.03595">Ouyang et al., 2021</a> considers invariance with respect to <strong>rotations, reflections,</strong> and <strong>translations</strong> (i.e., the Euclidean symmetry group) of the input city coordinates. They propose an autoregressive approach while ensuring invariance by performing data augmentation during the problem definition stage (pipeline step 1) and using relative coordinates during graph encoding (pipeline step 2). Their approach shows particularly strong results on zero-shot generalization from random instances to the real-world TSPLib benchmark suite.</p>
<p>Future work may follow the <a href="https://geometricdeeplearning.com/"><strong>Geometric Deep Learning (GDL)</strong></a> blueprint for architecture design. GDL tells us to explicitly think about and incorporate the symmetries and inductive biases that govern the data or problem at hand. As routing problems are <strong>embedded in Euclidean coordinates</strong> and the <strong>routes are cyclical</strong>, incorporating these constraints directly into the model architectures or learning paradigms may be a principled approach to improving generalization to instances larger than those seen during training.</p>
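<p>The invariance being exploited is easy to verify numerically: rotating, reflecting, or translating the city coordinates leaves all pairwise distances, and hence every tour length, unchanged. A small sketch (our own illustration of the kind of augmentation described above):</p>

```python
import numpy as np

def random_euclidean_transform(coords, rng):
    """Apply a random rotation, possible reflection, and translation --
    TSP tour lengths are invariant under all three."""
    theta = rng.uniform(0, 2 * np.pi)
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    if rng.random() < 0.5:
        R = R @ np.array([[1.0, 0.0], [0.0, -1.0]])  # reflection
    return coords @ R.T + rng.normal(size=2)          # rotate/reflect, then translate

rng = np.random.default_rng(0)
coords = rng.random((10, 2))
aug = random_euclidean_transform(coords, rng)
# Pairwise distances (hence any tour length) are unchanged.
d0 = np.linalg.norm(coords[:, None] - coords[None, :], axis=-1)
d1 = np.linalg.norm(aug[:, None] - aug[None, :], axis=-1)
```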
<h3 id="improved-graph-search-algorithms">Improved Graph Search Algorithms</h3>
<p>Another influential research direction has been the one-shot non-autoregressive Graph ConvNet approach [<a href="https://arxiv.org/abs/1906.01227">Joshi et al., 2019</a>]. Several recent papers have proposed to retain the same Gated GCN encoder (pipeline step 2) while replacing the beam search component (pipeline step 4) with <strong>more powerful</strong> and <strong>flexible graph search algorithms</strong>, e.g., Dynamic Programming [<a href="https://arxiv.org/abs/2102.11756">Kool et al., 2021</a>] or Monte-Carlo Tree Search (MCTS) [<a href="https://arxiv.org/abs/2012.10658">Fu et al., 2020</a>].</p>
<figure><center>
<img src="/images/routing-dl/heatmaps.png" width="75%" />
<figcaption><b>Figure 8:</b> The Gated GCN encoder <a href="https://arxiv.org/abs/1906.01227">[Joshi et al., 2019]</a> can be used to produce <b>edge prediction 'heatmaps'</b> (in transparent red color) for TSP, CVRP, and TSPTW. These can be further processed by <a href="https://arxiv.org/abs/2102.11756">DP</a> or <a href="https://arxiv.org/abs/2012.10658">MCTS</a> to output routes (in solid colors). The GCN essentially reduces the solution search space for sophisticated search algorithms which may have been intractable when searching over all possible routes. (Source: <a href="https://arxiv.org/abs/2102.11756">Kool et al., 2021</a>)</figcaption>
</center></figure>
<p>The <a href="https://arxiv.org/abs/2012.10658">GCN + MCTS framework</a> by Fu et al. in particular has a very interesting approach to <strong>training models efficiently on trivially small TSP</strong> and successfully <strong>transferring the learnt policy to larger graphs</strong> in a zero-shot fashion (something that the original GCN + Beam Search by Joshi et al. struggled with). They ensure that the predictions of the GCN encoder generalize from small to large TSP by updating the problem definition (pipeline step 1): large problem instances are represented as many smaller sub-graphs of the same size as the GCN's training graphs, and the GCN edge predictions are then merged before performing MCTS.</p>
<figure><center>
<img src="/images/routing-dl/sample-merge.png" width="70%" />
<figcaption><b>Figure 9:</b> The GCN + MCTS framework <a href="https://arxiv.org/abs/2012.10658">[Fu et al., 2020]</a> represents large TSPs as a set of <b>small sub-graphs</b> which are of the same size as the graphs used for training the GCN. Sub-graph edge heatmaps predicted by the GCN are merged together to obtain the heatmap for the full graph. This <b>divide-and-conquer approach</b> ensures that the embeddings and predictions made by the GCN generalize well from smaller to larger instances. (Source: <a href="https://arxiv.org/abs/2012.10658">Fu et al., 2020</a>)</figcaption>
</center></figure>
<p>Originally proposed by <a href="https://openreview.net/forum?id=B1jscMbAW">Nowak et al., 2018</a>, this <b>divide-and-conquer strategy</b> ensures that the embeddings and predictions made by GNNs generalize well from smaller to larger TSP instances up to 10,000 nodes. Fusing GNNs, divide-and-conquer, and search strategies has similarly shown promising results for tackling large-scale CVRPs with up to 3000 nodes [<a href="https://arxiv.org/abs/2107.04139">Li et al., 2021</a>].</p>
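<p>The merging step of this divide-and-conquer strategy can be sketched as averaging overlapping sub-graph predictions back into a full heatmap (a toy illustration with our own hypothetical function name; the merging heuristics in the actual papers are more involved):</p>

```python
import numpy as np

def merge_subgraph_heatmaps(n, subsets, sub_heats):
    """Average overlapping sub-graph edge predictions into one (n, n) heatmap.
    subsets[i] maps local indices of sub-graph i to global node ids;
    sub_heats[i] is that sub-graph's (k, k) edge-probability matrix."""
    heat = np.zeros((n, n))
    count = np.zeros((n, n))
    for nodes, sub in zip(subsets, sub_heats):
        idx = np.asarray(nodes)
        heat[np.ix_(idx, idx)] += sub   # accumulate predictions on global edges
        count[np.ix_(idx, idx)] += 1    # how many sub-graphs covered each edge
    # Average where covered; leave uncovered edges at zero.
    return np.divide(heat, count, out=np.zeros_like(heat), where=count > 0)

# Two overlapping 3-node sub-graphs of a 4-node instance.
subsets = [[0, 1, 2], [1, 2, 3]]
sub_heats = [np.full((3, 3), 0.2), np.full((3, 3), 0.6)]
full = merge_subgraph_heatmaps(4, subsets, sub_heats)
```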
<p>Overall, this line of work suggests that <strong>stronger coupling</strong> between the design of both the <strong>neural</strong> and <strong>symbolic/search</strong> components of models is essential for out-of-distribution generalization [<a href="https://arxiv.org/abs/2003.00330">Lamb et al., 2020</a>]. However, it is also worth noting that designing highly customized and parallelized implementations of graph search on GPUs may be challenging for each new problem.</p>
<h3 id="learning-to-improve-sub-optimal-solutions">Learning to Improve Sub-optimal Solutions</h3>
<p>Recently, a number of papers have explored an alternative to constructive AR and NAR decoding schemes which involves <strong>learning to iteratively improve (sub-optimal) solutions</strong> or <strong>learning to perform local search</strong>, starting with <a href="https://arxiv.org/abs/1810.00337">Chen et al., 2019</a> and <a href="https://arxiv.org/abs/1912.05784">Wu et al., 2021</a>. Other notable papers include the works of <a href="https://ojs.aaai.org/index.php/AAAI/article/view/16484">Cappart et al., 2021</a>, <a href="https://arxiv.org/abs/2004.01608">da Costa et al., 2020</a>, <a href="https://arxiv.org/abs/2110.02544">Ma et al., 2021</a>, <a href="https://arxiv.org/abs/2110.07983">Xin et al., 2021</a>, and <a href="https://arxiv.org/abs/2110.05291">Hudson et al., 2021</a>.</p>
<figure><center>
<img src="/images/routing-dl/cyclic-pe.png" width="75%" />
<figcaption><b>Figure 10:</b> Architectures which learn to improve sub-optimal TSP solutions by guiding decisions within local search algorithms. (a) The original Transformer encoder-decoder architecture <a href="https://arxiv.org/abs/1912.05784">[Wu et al., 2021]</a> which used <b>sinusoidal positional encodings</b> to represent the current sub-optimal tour permutation; (b) <a href="https://arxiv.org/abs/2110.02544">Ma et al., 2021</a>'s upgrade through the lens of symmetry: the Dual-aspect Transformer encoder-decoder with <b>learnable positional encodings</b> which capture the cyclic nature of TSP tours; (c) Visualizations of sinusoidal vs. cyclical positional encodings.</figcaption>
</center></figure>
<p>In all these works, since deep learning is used to <strong>guide decisions</strong> within classical local search algorithms (which are designed to work regardless of problem scale), this approach implicitly leads to <strong>better zero-shot generalization</strong> to larger problem instances compared to the constructive approaches. This is a very desirable property for practical implementations, as it may be intractable to train on very large or real-world TSP instances.</p>
<p>Notably, <strong>NeuroLKH</strong> [<a href="https://arxiv.org/abs/2110.07983">Xin et al., 2021</a>] uses edge probability heatmaps produced via GNNs to improve the <strong>classical Lin-Kernighan-Helsgaun algorithm</strong> and demonstrates strong zero-shot generalization to TSP with 5000 nodes as well as across TSPLib instances.</p>
<blockquote>
<p>For the interested reader, DeepMind’s <a href="https://arxiv.org/abs/2105.02761">Neural Algorithmic Reasoning</a> research program offers a unique meta-perspective on the intersection of neural networks with classical algorithms.</p>
</blockquote>
<p>A limitation of this line of work is the prior need for <strong>hand-designed local search algorithms</strong>, which may be missing for novel or understudied problems. On the other hand, constructive approaches are arguably easier to adapt to new problems by enforcing constraints during the solution decoding and search procedure.</p>
<h3 id="learning-paradigms-that-promote-generalization">Learning Paradigms that Promote Generalization</h3>
<p>Future work could look at <strong>novel learning paradigms</strong> (pipeline step 5) which explicitly focus on generalization beyond supervised and reinforcement learning, e.g., <a href="https://openreview.net/forum?id=90JprVrJBO">Hottung et al., 2020</a> explored autoencoder objectives to learn a continuous space of routing problem solutions.</p>
<p>At present, most papers propose to train models efficiently on trivially small and random TSPs, then transfer the learnt policy to larger graphs and real-world instances in a <strong>zero-shot</strong> fashion. The logical next step is to fine-tune the model on a small number of specific problem instances. <a href="https://arxiv.org/abs/2106.05126">Hottung et al., 2021</a> take a first step towards this by proposing to fine-tune a subset of model parameters for each specific problem instance via active search. In future work, it may be interesting to explore <strong>fine-tuning as a meta-learning problem</strong>, wherein the goal is to train model parameters specifically for fast adaptation to new data distributions and problems.</p>
<p>Another interesting direction could explore <strong>tackling understudied routing problems</strong> with challenging constraints via multi-task pre-training on popular routing problems such as TSP and CVRP, followed by problem-specific fine-tuning. Similar to <strong>language modelling as a pre-training objective</strong> in <a href="https://ruder.io/nlp-imagenet/">Natural Language Processing</a>, the goal of pre-training for routing would be to learn generally useful latent representations that can transfer well to novel routing problems.</p>
<h3 id="improved-evaluation-protocols">Improved Evaluation Protocols</h3>
<p>Beyond algorithmic innovations, there have been repeated calls from the community for <strong>more realistic evaluation protocols</strong> which can lead to advances on real-world routing problems and adoption by industry [<a href="https://arxiv.org/abs/1909.13121">Francois et al., 2019</a>, <a href="https://arxiv.org/abs/2002.09398">Yehuda et al., 2020</a>]. Most recently, <a href="https://arxiv.org/abs/2109.13983">Accorsi et al., 2021</a> have provided an authoritative set of <strong>guidelines for experiment design</strong> and <strong>comparisons</strong> to classical Operations Research (OR) techniques. They hope that fair and rigorous comparisons on <strong>standardized benchmarks</strong> will be the first step towards the integration of deep learning techniques into industrial routing solvers.</p>
<p>In general, it is encouraging to see recent papers move beyond showing minor performance boosts on <strong>trivially small random TSP instances</strong>, and towards <strong>embracing real-world benchmarks</strong> such as <a href="http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/">TSPLib</a> and <a href="http://vrp.atd-lab.inf.puc-rio.br/index.php/en/">CVRPLib</a>. Such routing problem collections contain graphs from cities and road networks around the globe along with their exact solutions, and have become the standard testbed for new solvers in the OR community.</p>
<p>At the same time, we must be wary of ‘overfitting’ to the top <code class="language-plaintext highlighter-rouge">n</code> TSPLib or CVRPLib instances that every other paper uses. Thus, better synthetic datasets go hand-in-hand with real-world benchmarks for measuring progress fairly, e.g., <a href="https://openreview.net/forum?id=yHiMXKN6nTl">Queiroga et al., 2021</a> recently proposed a new library of 10,000 synthetic CVRP test instances. Additionally, one can assess the robustness of neural solvers to small perturbations of problem instances with adversarial attacks, as proposed by <a href="https://arxiv.org/abs/2110.10942">Geisler et al., 2021</a>.</p>
<figure><center>
<img src="/images/routing-dl/ml4co.png" width="25%" />
<figcaption><b>Figure 11:</b> Community contests such as <a href="https://www.ecole.ai/2021/ml4co-competition/">ML4CO</a> are a great initiative to track progress. (Source: ML4CO website).</figcaption>
</center></figure>
<p><strong>Regular competitions</strong> on freshly curated real-world datasets, such as the <a href="https://arxiv.org/abs/2203.02433">ML4CO competition at NeurIPS 2021</a> and <a href="https://arxiv.org/abs/2201.10453">AI4TSP at IJCAI 2021</a>, are another great initiative to track progress in the intersection of deep learning and routing problems.</p>
<blockquote>
<p>We highly recommend the engaging panel discussion and talks from ML4CO, NeurIPS 2021, available on <a href="https://youtube.com/playlist?list=PLYWmzh0Y6EOZz3PtMxfaqEnRsfW-TF4nf">YouTube</a>.</p>
</blockquote>
<hr />
<h2 id="summary">Summary</h2>
<p>This blogpost presents a <strong>neural combinatorial optimization pipeline</strong> that unifies recent papers on deep learning for routing problems into a single framework. Through the lens of our framework, we then analyze and dissect recent advances, and speculate on directions for future research.</p>
<p>The following table highlights in <span style="color:red">Red</span> the major innovations and contributions for recent papers covered in the previous sections.</p>
<table>
<thead>
<tr>
<th>Paper</th>
<th>Definition</th>
<th>Graph Embedding</th>
<th>Solution Decoding</th>
<th>Solution Search</th>
<th>Policy Learning</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://arxiv.org/abs/2010.16011">Kwon et al., 2020</a></td>
<td>Full Graph</td>
<td>Transformer Encoder</td>
<td>Attention (AR)</td>
<td>Sampling</td>
<td><span style="color:red">POMO Rollout (RL)</span></td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/2012.10658">Fu et al., 2020</a></td>
<td><span style="color:red">Sparse Sub-graphs</span></td>
<td>Residual Gated GCN</td>
<td>MLP Heatmap (NAR)</td>
<td><span style="color:red">MCTS</span></td>
<td>Imitation (SL)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/2102.11756">Kool et al., 2021</a></td>
<td>Sparse Graph</td>
<td>Residual Gated GCN</td>
<td>MLP Heatmap (NAR)</td>
<td><span style="color:red">Dynamic Programming</span></td>
<td>Imitation (SL)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/2110.03595">Ouyang et al., 2021</a></td>
<td>Full Graph + <span style="color:red">Data Augmentation</span></td>
<td><span style="color:red">Equivariant GNN</span></td>
<td>Attention (AR)</td>
<td>Sampling + Local Search</td>
<td><span style="color:red">Policy Rollout (RL)</span></td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/1912.05784">Wu et al., 2021</a></td>
<td>Sequence + <span style="color:red">Position</span></td>
<td>Transformer Encoder</td>
<td><span style="color:red">Transformer Decoder (L2I)</span></td>
<td>Local Search</td>
<td>Actor-critic (RL)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/2004.01608">da Costa et al., 2020</a></td>
<td>Sequence</td>
<td>GCN</td>
<td><span style="color:red">RNN + Attention (L2I)</span></td>
<td>Local Search</td>
<td>Actor-critic (RL)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/2110.02544">Ma et al., 2021</a></td>
<td>Sequence + <span style="color:red">Cyclic Position</span></td>
<td><span style="color:red">Dual Transformer Encoder</span></td>
<td><span style="color:red">Dual Transformer Decoder (L2I)</span></td>
<td>Local Search</td>
<td><span style="color:red">PPO + Curriculum (RL)</span></td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/2110.07983">Xin et al., 2021</a></td>
<td>Sparse Graph</td>
<td>GAT</td>
<td>MLP Heatmap (NAR)</td>
<td><span style="color:red">LKH Algorithm</span></td>
<td>Imitation (SL)</td>
</tr>
<tr>
<td><a href="https://arxiv.org/abs/2110.05291">Hudson et al., 2021</a></td>
<td><span style="color:red">Sparse Dual Graph</span></td>
<td>GAT</td>
<td>MLP Heatmap (NAR)</td>
<td><span style="color:red">Guided Local Search</span></td>
<td>Imitation (SL)</td>
</tr>
</tbody>
</table>
<hr />
<p>As a final note, we would like to say that the <strong>more profound motivation</strong> of neural combinatorial optimization may NOT be to outperform classical approaches on well-studied routing problems. Neural networks may be used as a general tool for <strong>tackling previously un-encountered NP-hard problems</strong>, especially those that are non-trivial to design heuristics for. We are excited about recent applications of neural combinatorial optimization for <a href="https://www.nature.com/articles/s41586-021-03544-w">designing computer chips</a>, <a href="https://arxiv.org/abs/2109.10883">optimizing communication networks</a>, and <a href="https://openreview.net/forum?id=1QxveKM654">genome reconstruction</a>, and are looking forward to more in the future!</p>
<hr />
<p><strong>Acknowledgements</strong>: We would like to thank Goh Yong Liang, Yongdae Kwon, Yining Ma, Zhiguang Cao, Quentin Cappart, and Simon Geisler for helpful feedback and discussions.</p>Rishabh Anandmail.rishabh.anand@gmail.comThis blog post was written alongside Chaitanya Joshi, a good friend, senior-mentor, and current PhD student at Cambridge University. We are happy to announce this blog post was accepted (Top 50%) into the ICLR Blog Post Track 2022!!!Math Behind Graph Neural Networks2022-03-20T00:00:00+08:002022-03-20T00:00:00+08:00https://rish-16.github.io/posts/gnn-math<p><img src="/images/banner.png" width="100%" /></p>
<h2 id="foreword">Foreword</h2>
<blockquote>
<p>I’ve heard numerous requests to write something like this. The following blog post is my humble attempt at bridging the gaps in Graph Deep Learning. Don’t worry, I’ve added tons of diagrams and drawings to help visualise the whole thing! Also, I explicitly avoid the actual math-heavy concepts like spectral graph theory. Maybe an article for the future! The bulk of this article is comprehensible as long as you know the very basics of regular Machine Learning.</p>
</blockquote>
<hr />
<h2 id="preface">Preface</h2>
<p>Graph Deep Learning (GDL) <a href="https://twitter.com/prlz77/status/1178662575900368903">has picked up its pace over the years</a>. The natural network-like structure of many real-life problems makes GDL a versatile tool in the shed. The field has shown a lot of promise in social media, drug-discovery, chip placement, forecasting, bioinformatics, and more.</p>
<p>Here, I wish to provide a breakdown of popular Graph Neural Networks and their mathematical nuances – a very tiny survey, of sorts. Think of this as the continuation to my <a href="https://medium.com/dair-ai/an-illustrated-guide-to-graph-neural-networks-d5564a551783">previous article on Graph Neural Networks</a> that had no math at all.</p>
<p>⭐ The idea behind Graph Deep Learning is to learn the structural and spatial features over a graph with nodes and edges that represent entities and their interactions.</p>
<h3 id="structure-of-this-article">Structure of this article</h3>
<p>I start off by providing an in-depth breakdown of Graphs and Graph Neural Networks. Here, I deep dive into the granular steps one would take for the forward pass. Then, I move on to training these networks using familiar end-to-end techniques. Finally, I use the steps in the forward pass section as a framework or guideline to introduce popular Graph Neural Networks from the literature.</p>
<ul id="markdown-toc">
<li><a href="#foreword" id="markdown-toc-foreword">Foreword</a></li>
<li><a href="#preface" id="markdown-toc-preface">Preface</a> <ul>
<li><a href="#structure-of-this-article" id="markdown-toc-structure-of-this-article">Structure of this article</a></li>
</ul>
</li>
<li><a href="#representing-graphs" id="markdown-toc-representing-graphs">Representing Graphs</a> <ul>
<li><a href="#connection-to-images" id="markdown-toc-connection-to-images">Connection to Images</a></li>
</ul>
</li>
<li><a href="#graph-neural-networks" id="markdown-toc-graph-neural-networks">Graph Neural Networks</a> <ul>
<li><a href="#whats-in-a-node" id="markdown-toc-whats-in-a-node">What’s in a Node?</a></li>
<li><a href="#edges-matter-too" id="markdown-toc-edges-matter-too">Edges Matter Too!!!</a></li>
</ul>
</li>
<li><a href="#message-passing" id="markdown-toc-message-passing">Message Passing</a> <ul>
<li><a href="#aggregation" id="markdown-toc-aggregation">Aggregation</a></li>
<li><a href="#update" id="markdown-toc-update">Update</a></li>
<li><a href="#putting-them-together" id="markdown-toc-putting-them-together">Putting Them Together</a></li>
<li><a href="#working-with-edge-features" id="markdown-toc-working-with-edge-features">Working with Edge Features</a></li>
<li><a href="#working-with-adjacency-matrices" id="markdown-toc-working-with-adjacency-matrices">Working with Adjacency Matrices</a></li>
</ul>
</li>
<li><a href="#stacking-gnn-layers" id="markdown-toc-stacking-gnn-layers">Stacking GNN layers</a></li>
<li><a href="#training-a-gnn-context-node-classification" id="markdown-toc-training-a-gnn-context-node-classification">Training a GNN (context: Node Classification)</a> <ul>
<li><a href="#training-and-testing-graph-data" id="markdown-toc-training-and-testing-graph-data">Training and Testing Graph Data</a></li>
<li><a href="#backprop-and-gradient-descent" id="markdown-toc-backprop-and-gradient-descent">Backprop and Gradient Descent</a></li>
</ul>
</li>
<li><a href="#popular-graph-neural-networks" id="markdown-toc-popular-graph-neural-networks">Popular Graph Neural Networks</a> <ul>
<li><a href="#message-passing-neural-network" id="markdown-toc-message-passing-neural-network">Message Passing Neural Network</a></li>
<li><a href="#graph-convolutional-network" id="markdown-toc-graph-convolutional-network">Graph Convolutional Network</a></li>
<li><a href="#graph-attention-network" id="markdown-toc-graph-attention-network">Graph Attention Network</a></li>
<li><a href="#graphsage" id="markdown-toc-graphsage">GraphSAGE</a></li>
<li><a href="#temporal-graph-network" id="markdown-toc-temporal-graph-network">Temporal Graph Network</a></li>
</ul>
</li>
<li><a href="#conclusion" id="markdown-toc-conclusion">Conclusion</a> <ul>
<li><a href="#call-to-action" id="markdown-toc-call-to-action">Call To Action</a></li>
<li><a href="#acknowledgements" id="markdown-toc-acknowledgements">Acknowledgements</a></li>
</ul>
</li>
</ul>
<hr />
<h2 id="representing-graphs">Representing Graphs</h2>
<p>Before we get into Graph Neural Networks, let’s explore what a graph is in Computer Science.</p>
<p>A graph \(\mathcal{G}(V, E)\) is a data structure containing a set of vertices (nodes) \(i \in V\) and a set of edges \(e_{ij} \in E\) connecting vertices \(i\) and \(j\). If two nodes \(i\) and \(j\) are connected, the corresponding entry is \(1\), and \(0\) otherwise. One can store this connection information in an <strong><em>Adjacency Matrix</em></strong> \(A\):</p>
<p><img src="/images/adjmat2.png" width="100%" /></p>
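<p>As a quick sketch, here’s how the adjacency matrix of a small, hypothetical 5-node undirected graph can be built with plain <code class="language-plaintext highlighter-rouge">numpy</code>:</p>

```python
import numpy as np

# edges of a hypothetical undirected graph on 5 nodes (0-indexed)
edges = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)]

N = 5
A = np.zeros((N, N), dtype=int)
for i, j in edges:
    A[i, j] = 1
    A[j, i] = 1  # undirected: the matrix is symmetric

print(A)
```

<p>Each \(1\) in row \(i\), column \(j\) says “there is an edge between \(i\) and \(j\)”.</p>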
<p>⚠️ I assume the graphs in this article are <strong>unweighted</strong> (no edge weights or distances) and <strong>undirected</strong> (no direction of association between nodes). I also assume these graphs are <strong>homogeneous</strong> (a single type of nodes and edges; the opposite being “heterogeneous”).</p>
<p>Graphs differ from regular data in that they have a structure that neural networks must respect; it’d be a waste not to make use of it. Here’s an example of a social media graph where nodes are users and edges are their interactions (like follow/like/retweet).</p>
<p><img src="/images/socialmedia.png" width="100%" /></p>
<p><a href="https://threatpost.com/researchers-graph-social-networks-spot-spammers-061711/75346/">Source</a></p>
<h3 id="connection-to-images">Connection to Images</h3>
<p>An image is a graph on its own! It’s a special variant called a “Grid Graph” where the nodes are laid out in a regular lattice: every interior node has the same number of neighbours (border and corner nodes have fewer). This consistent structure in the image grid graph is what allows simple Convolution-like operations to be performed on it.</p>
<p>An image can be considered a special graph where each pixel is a node, connected to the pixels around it via imaginary edges. Of course, it’s impractical to view images in this light as that would mean having a very large graph. For instance, a simple \(32 \times 32 \times 3\) CIFAR-10 image would have \(1024\) pixel nodes (each carrying a \(3\)-dimensional RGB feature) and \(1984\) edges under 4-connectivity. For larger \(224 \times 224 \times 3\) ImageNet images, these numbers blow up.</p>
<p><img src="/images/gridgraph.png" width="100%" /></p>
<p>However, as you can observe, a general graph isn’t that neat. Different nodes have different degrees (numbers of connections to other nodes), and there is no fixed layout. But this irregular structure is exactly what adds value to the graph. So, any neural network that learns on a graph must respect this structure while learning the spatial relationships between the nodes (and edges).</p>
<p>😌 As much as we want to use image processing techniques here, it’d be nice to have special graph-specific methods that are efficient and comprehensive for both small and large graphs.</p>
<hr />
<h2 id="graph-neural-networks">Graph Neural Networks</h2>
<p>A single Graph Neural Network (GNN) layer has a bunch of steps that are performed on every node in the graph:</p>
<ol>
<li>Message Passing</li>
<li>Aggregation</li>
<li>Update</li>
</ol>
<p>Together, these form the building blocks that learn over graphs. Innovations in GDL mainly involve changes to these 3 steps.</p>
<h3 id="whats-in-a-node">What’s in a Node?</h3>
<p>Remember: a node represents an entity or object, like a user or atom. As such, this node has a bunch of properties characteristic to the entity being represented. These node properties form the features of a node (i.e., “node features” or “node embeddings”).</p>
<p>Typically, these features can be represented using vectors in \(\mathbb{R}^d\). This vector is either a latent-dimensional embedding or is constructed in a way where each entry is a different property of the entity.</p>
<p>🤔 For instance, in a social media graph, a user node has the properties of age, gender, political inclination, relationship status, etc. that can be represented numerically.</p>
<p>Likewise, in a molecule graph, an atom node might have chemical properties like affinity to water, forces, energies, etc. that can also be represented numerically.</p>
<p>These node features are the inputs to the GNN as we will see in the coming sections. Formally, every node \(i\) has associated node features \(x_i \in \mathbb{R}^d\) and labels \(y_i\) (that can either be continuous or discrete like <em><a href="https://en.wikipedia.org/wiki/One-hot">one-hot encodings</a></em>).</p>
<p><img src="/images/dataset.png" width="100%" /></p>
<h3 id="edges-matter-too">Edges Matter Too!!!</h3>
<p>Edges can have features \(a_{ij} \in \mathbb{R}^{d^\prime}\) as well, for instance, in cases where edges have meaning (like chemical bonds between atoms). We can think of the molecule shown below as a graph where atoms are nodes and bonds are edges.</p>
<p>While the atom nodes themselves have their respective feature vectors, the edges can have different edge features that encode the different types of bonds (single, double, triple). Though, for the sake of simplicity, I’ll be omitting edge features for the rest of this article.</p>
<p><img src="/images/molecule.png" width="100%" /></p>
<p>Now that we know how to represent nodes and edges in a graph, let’s start off with a simple graph with a bunch of nodes (with node features) and edges.</p>
<p><img src="/images/features.png" width="100%" /></p>
<h2 id="message-passing">Message Passing</h2>
<p>GNNs are known for their ability to learn structural information. Usually, nodes with similar features or properties are connected to each other (this is true in the social media setting). The GNN exploits this fact and learns how and why specific nodes connect to one another while some do not. To do so, the GNN looks at the Neighbourhoods of nodes.</p>
<blockquote>
<p>The <strong>Neighbourhood</strong> \(\mathcal{N}_i\) of a node \(i\) is defined as the set of nodes \(j\) connected to \(i\) by an edge. Formally, \(\mathcal{N}_i = \{j ~:~ e_{ij} \in E\}\).</p>
</blockquote>
<p><img src="/images/nbrhood.png" width="100%" /></p>
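<p>Given an adjacency matrix, reading off a node’s neighbourhood is just a matter of finding the \(1\)s in its row. A tiny sketch, using a hypothetical 7-node graph:</p>

```python
import numpy as np

# adjacency matrix of a hypothetical 7-node graph (0-indexed),
# where node 6 is connected to nodes 1, 3, and 4
A = np.zeros((7, 7), dtype=int)
for i, j in [(6, 1), (6, 3), (6, 4), (1, 3), (2, 5)]:
    A[i, j] = A[j, i] = 1

def neighbourhood(A, i):
    """N_i = {j : e_ij in E}, read off row i of the adjacency matrix."""
    return set(np.flatnonzero(A[i]).tolist())

print(neighbourhood(A, 6))  # → {1, 3, 4}
```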
<p>A person is shaped by the circle he is in. Similarly, a GNN can learn a lot about a node \(i\) by looking at the nodes in its neighbourhood \(\mathcal{N}_i\). To enable this sharing of information between a source node \(i\) and its neighbours \(j\), GNNs engage in <strong>Message Passing</strong>.</p>
<blockquote>
<p>For a GNN layer, Message Passing is defined as the process of taking node features of the neighbours, transforming them, and “passing” them to the source node. This process is repeated, in parallel, for all nodes in the graph. In that way, all neighbourhoods are examined by the end of this step.</p>
</blockquote>
<p>Let’s zoom into node \(6\) and examine the neighbourhood \(\mathcal{N}_6 = \{1,~3,~4\}\). We take each of the node features \(x_1\), \(x_3\), and \(x_4\), and transform them using a function \(F\), which can be a simple neural network (MLP or RNN) or an affine transform \(F(x_j) = \mathbf{W}_j \cdot x_j + b\). Simply put, a “message” is a neighbour’s transformed node feature, passed in to the source node.</p>
<p><img src="/images/messagepassing.png" width="100%" />
\(F\) can be a simple affine transform or neural network.</p>
<p>For now, let’s say \(F(x_j) = \mathbf{W}_j\cdot x_j\) for mathematical convenience. Here, \(\square \cdot \square\) represents simple matrix multiplication.</p>
<h3 id="aggregation">Aggregation</h3>
<p>Now that we have the transformed messages \(\{F(x_1), F(x_3), F(x_4)\}\) passed to node \(6\), we have to aggregate (“combine”) them in some way. Popular aggregation functions include:</p>
\[\begin{align} \text{Sum } &=\sum_{j \in \mathcal{N}_i} \mathbf{W}_j\cdot x_j \\ \text{Mean }&= \frac{\sum_{j \in \mathcal{N}_i} \mathbf{W}_j\cdot x_j}{|\mathcal{N}_i|} \\ \text{Max }&= \max_{j \in \mathcal{N}_i}(\{\mathbf{W}_j\cdot x_j\}) \\ \text{Min }&= \min_{j \in \mathcal{N}_i}(\{\mathbf{W}_j\cdot x_j\})\end{align}\]
<p>Suppose we use a function \(G\) to aggregate the neighbours’ messages (either using sum, mean, max, or min). The final aggregated messages can be denoted as follows:</p>
\[\bar{m}_i = G(\{\mathbf{W}_j \cdot x_j : j \in \mathcal{N}_i\})\]
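<p>Here’s a minimal <code class="language-plaintext highlighter-rouge">numpy</code> sketch of the four aggregators, assuming (hypothetically) a single shared message weight \(\mathbf{W}\) and the neighbourhood \(\mathcal{N}_6 = \{1, 3, 4\}\) from before:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_prime = 4, 8
W = rng.normal(size=(d_prime, d))                  # message transform: F(x_j) = W @ x_j
x = {j: rng.normal(size=d) for j in (1, 3, 4)}     # neighbour features of node 6

messages = np.stack([W @ x[j] for j in (1, 3, 4)])  # one row per neighbour message

m_sum  = messages.sum(axis=0)    # Sum
m_mean = messages.mean(axis=0)   # Mean
m_max  = messages.max(axis=0)    # Max (elementwise over neighbours)
m_min  = messages.min(axis=0)    # Min (elementwise over neighbours)
```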
<h3 id="update">Update</h3>
<p>Using these aggregated messages, the GNN layer now has to update the source node \(i\)’s features. At the end of this update step, the node should not only know about itself but its neighbours as well. This is ensured by taking the node \(i\)’s feature vector and combining it with the aggregated messages. Again, a simple addition or concatenation operation takes care of this.</p>
<p>Using addition:</p>
\[h_i = \sigma(K(H(x_i) + \bar{m}_i))\]
<p>where \(\sigma\) is an activation function (ReLU, ELU, Tanh), \(H\) is a simple neural network (MLP) or affine transform, and \(K\) is another MLP to project the added vectors into another dimension.</p>
<p>Using concatenation:</p>
\[h_i = \sigma(K(H(x_i) ~\oplus~ \bar{m}_i))\]
<p>To abstract this update step further, we can think of \(K\) as some projection function that transforms the messages and source node embedding together:</p>
\[h_i = \sigma(K(H(x_i),~ \bar{m}_i))\]
<p>👉🏻 Notation-wise, the initial node features are called \(x_i\).</p>
<p>After a forward pass through the first GNN layer, we call the node features \(h_i\) instead. Suppose we have more GNN layers, we can denote the node features as \(h_i^l\) where \(l\) is the current GNN layer index. Also, it’s evident that \(h_i^0 = x_i\) (i.e., the input to the GNN).</p>
<h3 id="putting-them-together">Putting Them Together</h3>
<p>Now that we’ve gone through the Message Passing, Aggregation, and Update steps, let’s put them all together to formulate a single GNN layer on a single node \(i\):</p>
\[h_i = \sigma \left( \mathbf{W}_1\cdot h_i + \sum_{j \in \mathcal{N}_i}\mathbf{W}_2\cdot h_j \right)\]
<p>Here, we use the <code class="language-plaintext highlighter-rouge">sum</code> aggregation and a simple feed-forward layer as functions \(F\) and \(H\).</p>
<p>⚠️ Do ensure that the dimensions of \(\mathbf{W}_1\) and \(\mathbf{W}_2\) are compatible with the node embeddings. If \(h_i \in \mathbb{R}^{d}\), then \(\mathbf{W}_1, \mathbf{W}_2 \in \mathbb{R}^{d^\prime \times d}\) where \(d^\prime\) is the embedding dimension.</p>
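<p>Putting the three steps into code, a single sum-aggregation GNN layer might look like this minimal sketch (the graph and weights below are hypothetical, and \(\sigma\) is ReLU):</p>

```python
import numpy as np

def gnn_layer(A, H, W1, W2):
    """h_i' = ReLU(W1 @ h_i + sum_{j in N_i} W2 @ h_j), for every node i."""
    H_new = []
    for i in range(H.shape[0]):
        nbrs = np.flatnonzero(A[i])                  # neighbourhood N_i
        m = sum(W2 @ H[j] for j in nbrs)             # message passing + sum aggregation
        H_new.append(np.maximum(W1 @ H[i] + m, 0))   # update, with ReLU as sigma
    return np.stack(H_new)

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]])      # hypothetical 3-node graph
H = rng.normal(size=(3, 4))                          # d = 4
W1, W2 = rng.normal(size=(8, 4)), rng.normal(size=(8, 4))  # d' = 8

H_out = gnn_layer(A, H, W1, W2)   # shape (3, 8)
```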
<h3 id="working-with-edge-features">Working with Edge Features</h3>
<p>When working with edge features, we’ll have to find a way to incorporate them into the GNN forward pass. Suppose edges have features \(a_{ij} \in \mathbb{R}^{d^\prime}\). To update them at a specific layer \(l\), we can factor in the embeddings of the nodes on either side of the edge. Formally,</p>
\[a^{l}_{ij} = T(h^l_i,~ h^l_j,~ a^{l-1}_{ij})\]
<p>where \(T\) is a simple neural network (MLP or RNN) that takes in the embeddings from connected nodes \(i\) and \(j\) as well as the previous layer’s edge embedding \(a^{l-1}_{ij}\).</p>
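<p>As a sketch, \(T\) can be as simple as a one-layer MLP over the concatenated inputs (the weight shapes below are hypothetical):</p>

```python
import numpy as np

def update_edge(h_i, h_j, a_ij, Wt):
    """a_ij' = T(h_i, h_j, a_ij): concatenate the inputs, then one tanh layer."""
    z = np.concatenate([h_i, h_j, a_ij])
    return np.tanh(Wt @ z)

rng = np.random.default_rng(0)
d, d_e = 8, 3                                 # node and edge feature dimensions
Wt = rng.normal(size=(d_e, 2 * d + d_e))      # maps the concatenation to a new edge feature
a_new = update_edge(rng.normal(size=d), rng.normal(size=d), rng.normal(size=d_e), Wt)
```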
<h3 id="working-with-adjacency-matrices">Working with Adjacency Matrices</h3>
<p>So far, we looked at the entire GNN forward pass through the lens of a single node \(i\) in isolation and its neighbourhood \(\mathcal{N}_i\). However, it’s also important to know how to implement the GNN forward pass when given a whole adjacency matrix \(A\) and all \(N = |V|\) node features stacked in \(X \in \mathbb{R}^{N \times d}\).</p>
<p>In regular Machine Learning, an MLP forward pass weights the entries of each feature vector. Per node, this is the product of the node feature vector \(x_i \in \mathbb{R}^d\) and a parameter matrix \(\mathbf{W} \in \mathbb{R}^{d^\prime \times d}\), where \(d^\prime\) is the embedding dimension:</p>
\[z_i = \mathbf{W} \cdot x_i ~~\in \mathbb{R}^{d^\prime}\]
<p>If we want to do this for all samples in the dataset (Vectorisation), we stack the feature vectors as the rows of \(\mathbf{X} \in \mathbb{R}^{N \times d}\) and matrix-multiply by the weight matrix (now written as \(\mathbf{W} \in \mathbb{R}^{d \times d^\prime}\), the transpose of the per-node matrix above) to get the transformed node features (messages):</p>
\[\mathbf{Z} = \mathbf{X}\mathbf{W} ~~\in \mathbb{R}^{N \times d^\prime}\]
<p>Now, in GNNs, for every node \(i\), a message aggregation operation involves taking neighbouring node feature vectors, transforming them, and adding them up (in the case of <code class="language-plaintext highlighter-rouge">sum</code> aggregation).</p>
<p>A single row \(A_i\) in the adjacency matrix tells us which nodes \(j\) are connected to \(i\). For every index \(j\) where \(A_{ij}=1\), we know nodes \(i\) and \(j\) are connected → \(e_{ij} \in E\).</p>
<p>For example, if \(A_2 = [1, 0, 1, 1, 0]\), we know that node \(2\) is connected to nodes \(1\), \(3\), and \(4\). So, when we multiply \(A_2\) with \(\mathbf{Z} = \mathbf{X}\mathbf{W}\), we only consider rows \(1\), \(3\), and \(4\) of \(\mathbf{Z}\) while ignoring rows \(2\) and \(5\). In terms of matrix multiplication, we are doing:</p>
<p><img src="/images/aggr.png" width="100%" /></p>
<p><img src="/images/matmul2.png" width="100%" /></p>
<p><img src="/images/matmul3.png" width="100%" />
Let’s focus on row 2 of \(A\).</p>
<p><img src="/images/matmul.png" width="100%" /></p>
<p>Matrix multiplication is simply the <strong>dot product</strong> of every row in \(A\) with every column in \(\mathbf{Z} = \mathbf{X}\mathbf{W}\)!!!</p>
<p>… and this is exactly what message aggregation is!!!</p>
<p>To get the aggregated messages for <em>all</em> \(N\) nodes in the graph based on their connections, we can matrix-multiply the entire adjacency matrix \(A\) with the transformed node features:</p>
\[\mathbf{Y} = A\mathbf{Z} = A\mathbf{X}\mathbf{W}\]
<p>‼️ <strong>A tiny problem:</strong> Observe that the aggregated messages do not factor in node \(i\)’s own feature vector (as we did above). To do that, we add self-loops to \(A\) (each node \(i\) is connected to itself).</p>
<p>This means changing the \(0\) to a \(1\) at every position \(A_{ii}\) (i.e., the diagonals).</p>
<p>With some linear algebra, we can do this using the Identity Matrix!</p>
\[\tilde{A} = A + I_{N}\]
<p><img src="/images/adjmat.png" width="100%" /></p>
<p>Adding self-loops allows the GNN to aggregate the source node’s features along with that of its neighbours!!</p>
<p>And with that, this is how you can do the GNN forward pass using matrices instead of single nodes.</p>
<p>⭐ To perform the <code class="language-plaintext highlighter-rouge">mean</code> aggregation, we can simply divide the sum by the count of \(1\)s in \(A_i\). For the example above, since there are three \(1\)s in \(A_2 = [1, 0, 1, 1, 0]\), we can divide \(\sum_{j \in \mathcal{N}_2}\mathbf{W}x_j\) by \(3\) … which is exactly the mean!!!</p>
<p>Though, it’s <em>not</em> possible to achieve <code class="language-plaintext highlighter-rouge">max</code> and <code class="language-plaintext highlighter-rouge">min</code> aggregation with the adjacency matrix formulation of GNNs, since matrix multiplication can only express weighted sums.</p>
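<p>The whole matrix formulation, including self-loops and the <code class="language-plaintext highlighter-rouge">sum</code>/<code class="language-plaintext highlighter-rouge">mean</code> variants, fits in a few lines (the graph and weights here are hypothetical):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, d_prime = 5, 4, 8
A = np.array([[0, 1, 0, 0, 1],
              [1, 0, 1, 1, 0],
              [0, 1, 0, 1, 0],
              [0, 1, 1, 0, 1],
              [1, 0, 0, 1, 0]])
X = rng.normal(size=(N, d))
W = rng.normal(size=(d, d_prime))

A_tilde = A + np.eye(N)                   # add self-loops: Ã = A + I_N
Z = X @ W                                 # transform all node features at once
Y_sum = A_tilde @ Z                       # sum aggregation (incl. each node itself)

deg = A_tilde.sum(axis=1, keepdims=True)  # neighbourhood sizes (with self-loop)
Y_mean = Y_sum / deg                      # mean aggregation
```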
<h2 id="stacking-gnn-layers">Stacking GNN layers</h2>
<p>Now that we’ve figured out how single GNN layers work, how do we build a whole “network” of these layers? How does information flow between the layers, and how does the GNN <em>refine</em> the embeddings/representations of the nodes (and/or edges)?</p>
<ol>
<li>The input to the first GNN layer is the node features \(X \in \mathbb{R}^{N \times d}\). The output is the intermediate node embeddings \(H^1 \in \mathbb{R}^{N \times d_1}\) where \(d_1\) is the first embedding dimension. \(H^1\) is made up of \(h^1_{i ~:~ 1 \rightarrow N} \in \mathbb{R}^{d_1}\).</li>
<li>\(H^1\) is the input to the second layer. The next output is \(H^2 \in \mathbb{R}^{N \times d_2}\) where \(d_2\) is the second layer’s embedding dimension. Likewise, \(H^2\) is made up of \(h^2_{i ~:~ 1 \rightarrow N} \in \mathbb{R}^{d_2}\).</li>
<li>After a few layers, at the output layer \(L\), the output is \(H^L \in \mathbb{R}^{N \times d_L}\). Finally, \(H^L\) is made up of \(h^L_{i ~:~ 1 \rightarrow N} \in \mathbb{R}^{d_L}\).</li>
</ol>
<p>The choice of \(\{d_1, d_2,\dots,d_L\}\) is completely up to us; these are hyperparameters of the GNN. Think of it as choosing the units (number of “neurons”) for a bunch of MLP layers.</p>
<p><img src="/images/fwdprop.png" width="100%" /></p>
<p>The node features/embeddings (“representations”) are passed through the GNN. The structure remains the same while the node representations constantly change through the layers. Optionally, the edge representations change too, but the connections and orientation do not.</p>
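<p>Stacking is just repeatedly applying a layer with different weight shapes. A minimal sketch (the graph is random and the dimensions \(\{d_1, d_2, d_3\}\) are arbitrary hyperparameter choices):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N = 5
A = (rng.random((N, N)) < 0.5).astype(int)
A = np.triu(A, 1); A = A + A.T               # random symmetric adjacency, no self-edges
A_tilde = A + np.eye(N)                      # add self-loops

dims = [4, 16, 16, 8]                        # [d, d_1, d_2, d_3] (our choice)
Ws = [rng.normal(size=(dims[l], dims[l + 1])) for l in range(len(dims) - 1)]

H = rng.normal(size=(N, dims[0]))            # H^0 = X
for W in Ws:
    H = np.maximum(A_tilde @ H @ W, 0)       # one simple sum-aggregation layer per step
# H is now H^L, of shape (N, d_L) = (5, 8)
```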
<p><strong>Now, there are a few things we can do with \(H^L\)</strong>:</p>
<ul>
<li>We can sum it along the first axis (i.e., \(\sum_{k=1}^N h_k^L\)) to get a vector in \(\mathbb{R}^{d_L}\). This vector is a latent representation of the <em>whole</em> graph. It can be used for graph classification (eg: what molecule is this?).</li>
</ul>
<p><img src="/images/wholegraphclass.png" width="100%" /></p>
<ul>
<li>We can concatenate the vectors in \(H^L\) (i.e., \(\bigoplus_{k=1}^N h_k\) where \(\oplus\) is the vector concatenation operation) and pass it through a <a href="https://arxiv.org/abs/1611.07308">Graph Autoencoder</a>. This might help when the input graphs are noisy or corrupted and we want to reconstruct the denoised graph.</li>
</ul>
<p><img src="/images/gae.png" width="100%" /></p>
<ul>
<li>We can do <strong>node classification</strong> → what class does this node belong to?
The node embedding at a specific index \(h_i^L\) (\(i : 1 \rightarrow N\)) can be put through a classifier (like a MLP) into \(K\) classes (eg: is this a Carbon atom, Hydrogen atom, or Oxygen atom?).</li>
</ul>
<p><img src="/images/gnnclassifier.png" width="100%" /></p>
<ul>
<li>We can perform <strong>link prediction</strong> → should there be a link between some node \(i\) and \(j\)?
The node embeddings for \(h_i^L\) and \(h_j^L\) can be fed into another Sigmoid-based MLP that spits out a probability of an edge existing between those nodes.</li>
</ul>
<p><img src="/images/edgepred.png" width="100%" /></p>
<p>Either way, the fun thing is, the vectors \(h_{1 \rightarrow N} \in H^L\) can be stacked and treated as a <strong>batch</strong> of samples.</p>
<p>🚨 For a given node \(i\), the \(l^{\text{th}}\) layer in the GNN aggregates features from the \(l\)-hop neighbourhood of node \(i\). Initially, the node sees its immediate neighbours; deeper into the network, it interacts with its neighbours’ neighbours, and so on.</p>
<p>This is why, for very small, sparse (very few edges) graphs, a large number of GNN layers often leads to a degradation in performance: the node embeddings all converge to a single indistinguishable vector, since each node has seen nodes many hops away. This phenomenon is known as <strong>over-smoothing</strong>, and it’s a useless situation to be in!!!</p>
<p>This explains why most GNN papers use \(\leq 4\) layers in their experiments, to prevent the network from dying.</p>
<hr />
<h2 id="training-a-gnn-context-node-classification">Training a GNN (context: Node Classification)</h2>
<p>🥳 During training, the predictions for nodes, edges, or the whole graph can be compared with the ground-truth labels from the dataset using a loss function (eg: Cross Entropy).</p>
<p>This enables GNNs to be trained in an end-to-end manner using <strong>vanilla Backprop and Gradient Descent</strong>.</p>
<h3 id="training-and-testing-graph-data">Training and Testing Graph Data</h3>
<p>As with regular ML, graph data can be split into training and testing as well. This can be done in one of two ways:</p>
<p><strong>Transductive</strong></p>
<p>The training and testing data are both present in the same graph. The nodes from each set are connected to one another. It’s just that, during training, the labels for the testing nodes are hidden while the labels for training nodes are visible. However, the features of ALL nodes are visible to the GNN.</p>
<p>We can do this with a binary mask over all the nodes: the loss is computed only on nodes where the mask is \(1\) (the training nodes), so the testing nodes’ labels never influence the gradients.</p>
<p><img src="/images/transductive.png" width="100%" />
In the transductive setting, training and testing nodes are both part of the SAME graph. Just that training nodes expose their features and labels while testing nodes only expose their features. The testing labels are hidden from the model. <strong>Binary masks</strong> are needed to tell the GNN what’s a training node and what’s a testing node.</p>
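<p>A sketch of the transductive masked loss (the node count, mask, and logits below are all hypothetical): the forward pass runs over <em>every</em> node, but only the training nodes contribute to the loss.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 6, 3
logits = rng.normal(size=(N, K))                       # GNN/MLP outputs for ALL nodes
labels = rng.integers(0, K, size=N)                    # ground-truth classes
train_mask = np.array([1, 1, 1, 0, 0, 0], dtype=bool)  # nodes 0-2 are training nodes

# softmax cross-entropy, averaged ONLY over the masked (training) nodes
shifted = logits - logits.max(axis=1, keepdims=True)   # for numerical stability
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
loss = -np.log(probs[train_mask, labels[train_mask]]).mean()
```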
<p><strong>Inductive</strong></p>
<p>Here, there are separate training and testing graphs that are hidden from one another. This is akin to regular ML where the model only sees the features and labels during training, and only the features for testing. Training and testing take place on two separate, isolated graphs. Sometimes, these testing graphs are out-of-distribution to check for quality of generalisation during training.</p>
<p><img src="/images/inductive.png" width="100%" /></p>
<p>Like regular ML, the training and testing data are kept separately. The GNN makes use of features and labels ONLY from the training nodes. There is no binary mask needed here to hide the testing nodes as they are from a different set.</p>
<h3 id="backprop-and-gradient-descent">Backprop and Gradient Descent</h3>
<p>During training, once we do the forward pass through the GNN, we get the final node representations \(h^L_i \in H^L\). To train the network in an end-to-end manner, we can do the following:</p>
<ol>
<li>Feed each \(h^L_i\) into a MLP classifier to get prediction \(\hat{y}_i\)</li>
<li>Calculate loss using ground-truth \(y_i\) and prediction \(\hat{y}_i\) → \(J(\hat{y}_i, y_i)\)</li>
<li>Use Backpropagation to compute gradients \(\frac{\partial J}{\partial W^l}\) where \(W^l\) is the parameter matrix from layer \(l\)</li>
<li>Use some optimiser (like Gradient Descent) to update the parameters \(W^l\) for each layer in the GNN</li>
<li>(Optional) You can finetune the classifier (MLP) network’s weights as well.</li>
</ol>
<p><img src="/images/backprop.png" width="100%" /></p>
<p>🥳 This means GNNs are easily parallelisable both in terms of Message Passing <em>and</em> Training. The entire process can be vectorised (as shown above) and performed on GPUs!!!</p>
<hr />
<h2 id="popular-graph-neural-networks">Popular Graph Neural Networks</h2>
<p>In this section, I cover some popular works in the literature and categorise their equations and math into the 3 GNN steps mentioned above (or at least I try). A lot of popular architectures merge the Message Passing and Aggregation steps into one function performed together, rather than one after the other explicitly. I try to decompose them in this section, but for mathematical convenience, it’s best to see them as a single operation!</p>
<blockquote>
<p>I’ve adapted the notation of the networks covered in this section to make it consistent with that of this article.</p>
</blockquote>
<h3 id="message-passing-neural-network">Message Passing Neural Network</h3>
<p><a href="https://arxiv.org/abs/1704.01212">Neural Message Passing for Quantum Chemistry</a></p>
<p>Message Passing Neural Networks (MPNN) decompose the forward pass into the <strong>Message Passing Phase</strong> with message function \(M_l\), and <strong>Readout Phase</strong> with vertex update function \(U_l\).</p>
<p>MPNN merges the Message Passing and Aggregation steps into the singular <strong>Message Passing Phase</strong>:</p>
\[m_i^{l+1} = \sum_{j \in \mathcal{N}_i} M_l(h_i^l,~ h_j^l,~ e_{ij})\]
<p>The <strong>Readout Phase</strong> is the update step:</p>
\[h_i^{l+1} = U_l(h_i^l,~ m_i^{l+1})\]
<p>where \(m_i^{l+1}\) is the <strong>aggregated message</strong> and \(h_i^{l+1}\) is the <strong>updated node embedding</strong>. This is very similar to the procedure I’ve mentioned above: the message function \(M_l\) is a mix of \(F\) and \(G\), and the function \(U_l\) is \(K\). Here, \(e_{ij}\) refers to optional edge features that can be omitted.</p>
<p><img src="/images/mpnn.png" width="100%" /></p>
<p>The paper presents MPNN as a general framework and formulates other works from the literature as special variants of it. The authors further apply MPNNs to quantum chemistry applications.</p>
<h3 id="graph-convolutional-network">Graph Convolutional Network</h3>
<p><a href="https://arxiv.org/abs/1609.02907">Semi-Supervised Classification with Graph Convolutional Networks</a></p>
<p>The Graph Convolutional Network (GCN) paper looks at the whole graph in its adjacency matrix form. First, self-connections are added to the adjacency matrix to ensure all nodes are connected to themselves to get \(\tilde{A}\). This ensures we factor in the source node’s embeddings during Message Aggregation. The combined Message Aggregation and Update steps look like so:</p>
\[H^{l+1} = \sigma(\tilde{A}H^lW^l)\]
<p>where \(W^l\) is a learnable parameter matrix. Of course, I change \(X\) to \(H\) to generalise the node features at any arbitrary layer \(l\), where \(H^0 = X\).</p>
<p>🤔 Due to the associative property of matrix multiplication (\(A(BC) = (AB)C\)), it doesn’t matter in which order we multiply the matrices (either \(\tilde{A}H^l\) first and post-multiply by \(W^l\), <strong>OR</strong> \(H^lW^l\) first and pre-multiply by \(\tilde{A}\)).</p>
<p>However, the authors, Kipf and Welling, further introduce a degree matrix \(\tilde{D}\) as a form of renormalisation to avoid numerical instabilities and exploding/vanishing gradients:</p>
\[\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}\]
<p>The “renormalisation” is carried out on the augmented adjacency matrix \(\hat{A} = \tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}\). Altogether, the new combined Message Passing and Update steps look like so:</p>
\[H^{l+1} = \sigma(\hat{A}H^lW^l)\]
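<p>The renormalisation is straightforward to sketch in code (the input graph below is hypothetical):</p>

```python
import numpy as np

def gcn_norm(A):
    """Compute Â = D̃^{-1/2} Ã D̃^{-1/2}, where Ã = A + I and D̃ is Ã's degree matrix."""
    A_tilde = A + np.eye(A.shape[0])
    d = A_tilde.sum(axis=1)                  # D̃_ii = sum_j Ã_ij
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
    return D_inv_sqrt @ A_tilde @ D_inv_sqrt

A = np.array([[0, 1, 1],
              [1, 0, 0],
              [1, 0, 0]])
A_hat = gcn_norm(A)   # each entry is now scaled by the degrees of its two endpoints
```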
<h3 id="graph-attention-network">Graph Attention Network</h3>
<p><a href="https://arxiv.org/abs/1710.10903">Graph Attention Networks</a></p>
<p>Aggregation typically involves treating all neighbours <strong>equally</strong> in the sum, mean, max, and min settings. However, in most situations, some neighbours are more important than others. Graph Attention Networks (GAT) account for this by weighting the edges between a source node and its neighbours using <em>Self-Attention</em> (<a href="https://arxiv.org/abs/1706.03762">Vaswani et al., 2017</a>).</p>
<p>Edge weights \(\alpha_{ij}\) are generated as follows.</p>
\[\alpha_{ij} = \text{Softmax}(\text{LeakyReLU}(\mathbf{W_a}^T \cdot [\mathbf{W}h^l_i ~\oplus~ \mathbf{W}h^l_j]))\]
<p>where \(\mathbf{W_a} \in \mathbb{R}^{2d^\prime}\) and \(\mathbf{W} \in \mathbb{R}^{d^\prime \times d}\) are learned parameters, \(d^\prime\) is the embedding dimension, and \(\oplus\) is the vector concatenation operation.</p>
<p>While the initial Message Passing step remains the same as MPNN/GCN, the combined Message Aggregation and Update steps are a weighted sum over all the neighbours and the node itself:</p>
\[h_i = \sum_{j \in \mathcal{N}_i~\cup ~\{i\}}\alpha_{ij} ~\cdot~ \mathbf{W}h_j^l\]
<p><img src="/images/gat.png" width="100%" /></p>
<p>Edge importance weighting helps us understand how much each neighbour affects the source node.</p>
<p>As with the GCN, self-loops are added so source nodes can factor in their own representations for future representations.</p>
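<p>A single-head sketch of the attention coefficients (the weights below are hypothetical; in practice, GAT uses multiple attention heads):</p>

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def gat_alphas(h_i, h_nbrs, W, w_a, slope=0.2):
    """alpha_ij = softmax_j(LeakyReLU(w_a^T [W h_i ⊕ W h_j])), over j in N_i ∪ {i}."""
    z_i = W @ h_i
    scores = []
    for h_j in [h_i] + h_nbrs:                     # include the self-loop
        s = w_a @ np.concatenate([z_i, W @ h_j])
        scores.append(s if s > 0 else slope * s)   # LeakyReLU
    return softmax(np.array(scores))               # normalise over the neighbourhood

rng = np.random.default_rng(0)
d, d_prime = 4, 8
W = rng.normal(size=(d_prime, d))
w_a = rng.normal(size=2 * d_prime)
alphas = gat_alphas(rng.normal(size=d), [rng.normal(size=d) for _ in range(3)], W, w_a)
```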
<h3 id="graphsage">GraphSAGE</h3>
<p><a href="https://arxiv.org/abs/1706.02216">Inductive Representation Learning on Large Graphs</a></p>
<p>GraphSAGE stands for Graph <strong>SA</strong>mple and Aggre<strong>G</strong>at<strong>E</strong>. It’s a model to generate node embeddings for large, very dense graphs (to be used at companies like Pinterest).</p>
<p>The work introduces learned aggregators on a node’s neighbourhoods. Unlike traditional GATs or GCNs that consider all nodes in the neighbourhood, GraphSAGE uniformly samples the neighbours and uses the learned aggregators on them.</p>
<p>Suppose we have \(L\) layers in the network (the depth); each layer \(l \in \{1,\dots,L\}\) looks at a larger \(l\)-hop neighbourhood w.r.t. the source node (as one would expect). Each source node is then updated by concatenating its embedding with the aggregated sampled messages, before being passed through an MLP \(F\) and non-linearity \(\sigma\).</p>
<p>For a certain layer \(l\),</p>
\[h^l_{\mathcal{N}(i)} = \text{AGGREGATE}_{l}(\{h^{l-1}_j : j \in \mathcal{N}(i)\}) \\ h^l_i = \sigma(F(h^{l-1}_i ~\oplus~ h^l_{\mathcal{N}(i)}))\]
<p>where \(\oplus\) is the vector concatenation operation, and \(\mathcal{N}(i)\) is the uniform sampling function that returns a subset of all neighbours. So, if a node has 5 neighbours \(\{1,2,3,4,5\}\), possible outputs from \(\mathcal{N}(i)\) would be \(\{1,4,5\}\) or \(\{2,5\}\).</p>
<p><img src="/images/graphsage.png" width="100%" /></p>
<p>Aggregator \(l = 1\) aggregates sampled nodes (coloured) from the \(1\)-hop neighbourhood while Aggregator \(l = 2\) aggregates sampled nodes (coloured) from the \(2\)-hop neighbourhood.</p>
<p>Possible future work could be experimenting with non-uniform sampling functions to choose neighbours.</p>
<blockquote>
<p><strong>Note:</strong> In the paper, the authors use \(K\) and \(k\) to denote the layer index. In this article, I use \(L\) and \(l\) respectively to stay consistent. Furthermore, the paper uses \(v\) to denote the source node \(i\) and \(u\) to denote the neighbour \(j\).</p>
</blockquote>
<p><strong>Bonus:</strong> Prior work to GraphSAGE includes <a href="https://arxiv.org/abs/1403.6652"><strong>DeepWalk</strong></a>. Check it out!</p>
<h3 id="temporal-graph-network">Temporal Graph Network</h3>
<p><a href="https://arxiv.org/abs/2006.10637">Temporal Graph Networks for Deep Learning on Dynamic Graphs</a></p>
<p>The networks described so far work on static graphs. Most real-life situations involve dynamic graphs, where nodes and edges are added, deleted, or updated over time. The Temporal Graph Network (TGN) works on continuous-time dynamic graphs (CTDGs), which can be represented as a chronologically sorted list of events.</p>
<p>The paper breaks down events into two types: <strong>node-level events</strong> and <strong>interaction events</strong>. Node-level events involve a node in isolation (eg: a user updates their profile’s bio) while interaction events involve two nodes that may or may not be connected (eg: user A retweets/follows user B).</p>
<p>TGN offers a modular approach to CTDG processing with the following components:</p>
<ol>
<li>
<p><strong>Message Passing Function</strong> → message passing between isolated nodes or interacting nodes (for either type of event).</p>
</li>
<li>
<p><strong>Message Aggregation Function</strong> → uses the GAT’s aggregation by looking at a <em>temporal neighbourhood</em> through many timesteps instead of a local neighbourhood at a given timestep.</p>
</li>
<li>
<p><strong>Memory Updater</strong> → memory allows the nodes to have long-term dependencies and represents the history of the node in latent (“compressed”) space. This module updates the node’s memory based on the interactions taking place through time.</p>
</li>
<li>
<p><strong>Temporal Embedding</strong> → a way to represent the nodes that captures the essence of time as well.</p>
</li>
<li>
<p><strong>Link prediction</strong> → the temporal embeddings of the nodes involved in an event are fed through some neural network to calculate edge probabilities (i.e., will the edge occur in the future?).
Of course, during training, we know the edge exists, so the edge label is \(1\). We need to train the Sigmoid-based network to predict this as usual.</p>
</li>
</ol>
<p><img src="/images/tgn1.png" width="100%" /></p>
<p>Every time a node is involved in an activity (node update or inter-node interaction), the memory is updated.</p>
<p><strong>(1)</strong> For each event \(1\) and \(2\) in the batch, TGN generates messages for all nodes involved in that event.</p>
<p><strong>(2)</strong> Next, TGN aggregates the messages of each node \(m_i\) across all timesteps \(t\); this set is called the temporal neighbourhood of node \(i\).</p>
<p><strong>(3)</strong> Next, TGN uses the aggregated messages \(\bar{m}_i(t)\) to update the memory of each node \(s_i(t)\).</p>
<p><img src="/images/tgn2.png" width="100%" /></p>
<p><strong>(4)</strong> Once the memory \(s_i(t)\) is up to date for all nodes, it’s used to compute “temporal node embeddings” \(z_i(t)\) for all nodes used in the specific interactions in the batch.</p>
<p><strong>(5)</strong> These node embeddings are then fed into an MLP or neural network to get the probability of each of the events taking place (using a <em>Sigmoid</em> activation).</p>
<p><strong>(6)</strong> We can then compute the loss using Binary Cross Entropy (BCE) as usual (not shown).</p>
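The six steps can be strung together as a toy pipeline. Everything below is a stand-in for illustration only: the message, aggregation, and memory-update functions are simple averages and blends rather than TGN’s learned modules, and node states are scalars.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def tgn_step(memory, events):
    """Toy TGN batch step over interaction events (src, dst).

    (1) generate messages, (2) aggregate per node, (3) update memory,
    (4) compute temporal embeddings, (5) score edge probabilities.
    Stand-ins: message = mean of the two memories, aggregation = mean,
    memory update = 50/50 blend, embedding = the memory itself.
    """
    msgs = {}
    for src, dst in events:  # (1) one message per endpoint
        m = (memory[src] + memory[dst]) / 2.0
        msgs.setdefault(src, []).append(m)
        msgs.setdefault(dst, []).append(m)
    agg = {i: sum(ms) / len(ms) for i, ms in msgs.items()}  # (2)
    for i, m_bar in agg.items():  # (3) blend old memory with the aggregate
        memory[i] = 0.5 * memory[i] + 0.5 * m_bar
    z = dict(memory)  # (4) "temporal embeddings"
    probs = {(s, d): sigmoid(z[s] * z[d]) for s, d in events}  # (5)
    return probs  # (6) BCE against the observed edges would follow

memory = {0: 1.0, 1: 1.0}
probs = tgn_step(memory, [(0, 1)])
```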
<p>For more on the TGN, check out my short paper review here: <a href="https://www.notion.so/Temporal-Graph-Networks-for-Deep-Learning-on-Dynamic-Graphs-9da6cdd5ff2948d6882f6367106d4bff"><strong>Temporal Graph Networks for Deep Learning on Dynamic Graphs</strong></a></p>
<p>The authors have also written a blogpost on the TGN that can be found here:</p>
<p><a href="https://towardsdatascience.com/temporal-graph-networks-ab8f327f2efe">Temporal Graph Networks</a></p>
<hr />
<h2 id="conclusion">Conclusion</h2>
<p>Graph Deep Learning is a great toolset when working with problems that have a network-like structure. GNNs are simple to understand and implement using libraries like <code class="language-plaintext highlighter-rouge">PyTorch Geometric</code>, <code class="language-plaintext highlighter-rouge">Spektral</code>, <code class="language-plaintext highlighter-rouge">Deep Graph Library</code>, <code class="language-plaintext highlighter-rouge">Jraph</code> (if you use <code class="language-plaintext highlighter-rouge">jax</code>), and now, the recently-released <code class="language-plaintext highlighter-rouge">TensorFlow-gnn</code>. GDL has shown promise and will continue to grow as a field. In fact, most popular GDL papers come with a codebase written in either PyTorch or TensorFlow, so it helps a lot with experimentation.</p>
<p>Fun fact: GDL now falls under the umbrella of Geometric Deep Learning (“the new GDL”), which learns structural and spatial inductive biases on geometric surfaces like manifolds, graphs, and miscellaneous topologies. There are many academics who now specialise in Geometric DL, with lots of exciting works coming out every month. In fact, I’d recommend taking a look at the dedicated Geometric Deep Learning <a href="https://geometricdeeplearning.com/lectures/">course</a> that features material by its authors: Michael M. Bronstein, Joan Bruna, Taco Cohen, and Petar Veličković.</p>
<p>Till then, I’ll see you in the next post! Happy reading 😊</p>
<hr />
<p>🙏🏻 If you want me to cover any of these papers or methods in more detail, email me at <a href="mailto:mail.rishabh.anand@gmail.com">mail.rishabh.anand@gmail.com</a> or DM me on Twitter at <a href="http://twitter.com/rishabh16_">@rishabh16_</a>.</p>
<h3 id="call-to-action">Call To Action</h3>
<p>If you like what you just read, there’s plenty more where that came from! Subscribe to my email newsletter / tech blog on Machine Learning and Technology:</p>
<p><a href="http://rishtech.substack.com/">RishTech</a></p>
<h3 id="acknowledgements">Acknowledgements</h3>
<p>I’d like to express my gratitude to Alex Foo and Chaitanya Joshi for the useful feedback and comments on this article. Our fruitful conversations and discussions shaped much of the work surrounding the GNN equations and their ease of readability. Thank you to Petar Veličković for his comments post-publication!</p>
<hr />
<h2>[WIP] Elo Loss: From Chess Ratings to Objective Functions (2022-01-30)</h2>
<h2 id="preface">Preface</h2>
<p>I rewatched the DeepMind documentary on AlphaGo recently on a whim. It mentioned the <a href="https://en.wikipedia.org/wiki/Elo_rating_system">Elo Rating</a> used to rate players in competitive games like Chess, Go, Baseball, general boardgames, and more. This got me curious about the scoring system, and I read up quite a bit about it.</p>
<p>In this post, I document my rather crazy idea of repurposing the Elo Rating formula into a loss function to train networks and somehow, it worked out pretty well. I describe my experimental setup and showcase some interesting results.</p>
<h2 id="elo-rating">Elo Rating</h2>
<p>Elo rating is used to calculate relative skill levels between players. Players in such games all start off with a base rating and move up the ladder based on calculations involving the Elo rating. It looks rather convoluted, and a quick internet search doesn’t give much information on it either. Let me quickly break it down for you here. For starters, given two players <strong>A</strong> and <strong>B</strong>, the formula looks like this:</p>
<p><img src="/images/elo.jpg" width="100%" /></p>
<p>\(E_A\) denotes the expected probability of <strong>A</strong> winning the game given \(R_A\) and \(R_B\), the ratings of players <strong>A</strong> and <strong>B</strong> respectively; ultimately, we’re calculating \(P(\text{A wins})\). You can do this for player <strong>B</strong> too: in the denominator, you have to switch it to \(R_A - R_B\) instead. This function always gives a probability value \(\leq 1\).</p>
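As a sketch, assuming the standard 400-point Elo scale:

```python
def expected_score(r_a, r_b):
    """E_A: the probability that player A beats player B under Elo,
    assuming the conventional 400-point scale."""
    return 1.0 / (1.0 + 10 ** ((r_b - r_a) / 400))

# equal ratings -> 50/50; a 400-point gap -> ~91% for the stronger player
p_equal = expected_score(1500, 1500)   # 0.5
p_strong = expected_score(1900, 1500)  # ~0.909
```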
<p>To update the score for player <strong>X</strong>, the following formula is used:</p>
<p><img src="/images/update.jpg" width="100%" /></p>
<p>where \(K\) is some constant usually set by the tournament’s presiding members and \(S_X \in \{1, 0, \frac{1}{2}\}\) is the result of the game (\(1\) means win, \(0\) means lose, \(\frac{1}{2}\) means draw). Notice that the lower the expected score \(E_X\), the greater the “surprise” of player <strong>X</strong> winning the game (i.e., \(S_X=1\)), resulting in a proportionally higher rating hike for <strong>X</strong>. Likewise, if player <strong>X</strong> is already heavily favoured to win (i.e., a high \(E_X\)), the smaller the rating hike for them.</p>
<blockquote>
<p>Of course, this update step is not of concern in this blog post. It’s just a funky detail I decided to add for completeness.</p>
</blockquote>
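For completeness, the update step in code; \(K=32\) here is just a common choice, not a universal constant:

```python
def update_rating(r_x, e_x, s_x, k=32):
    """R'_X = R_X + K * (S_X - E_X); s_x is 1 (win), 0 (loss), or 0.5 (draw).
    k=32 is a commonly used value, not mandated by the formula."""
    return r_x + k * (s_x - e_x)

# an underdog (E_X = 0.25) gains three times as much by winning
# as they would lose by losing
gain = update_rating(1400, 0.25, 1) - 1400  # +24.0
loss = update_rating(1400, 0.25, 0) - 1400  # -8.0
```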
<h2 id="using-elo-as-a-loss-function">Using Elo as a loss function</h2>
<p>The Elo score function \(E_X\) for some player <strong>X</strong> is an example of a logistic function that has the following general form:</p>
<p><img src="/images/logistic.jpg" width="100%" /></p>
<p>In fact, when \(L=1\), \(k=1\), and \(x_0 = 0\), it’s called the <em>Sigmoid</em>, i.e., the activation function used to inject non-linearity into perceptrons and neural networks; it’s mainly used in binary classification problems given some classification threshold. Here, however, I use this Elo logistic function as a loss function, not an activation. It’s best to make that distinction before showcasing some interesting results.</p>
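To make the connection concrete, here is the general logistic alongside the Elo curve; note that Elo’s \(E_A\) is exactly this function with \(k = \ln(10)/400\) and \(x = R_A - R_B\):

```python
import math

def logistic(x, L=1.0, k=1.0, x0=0.0):
    """General logistic function f(x) = L / (1 + e^{-k(x - x0)})."""
    return L / (1.0 + math.exp(-k * (x - x0)))

# with L = k = 1 and x0 = 0 this reduces to the Sigmoid
mid = logistic(0.0)  # 0.5

# Elo's expected score is the same curve, since 10^t = e^{t ln 10}:
# 1/(1 + 10^((R_B - R_A)/400)) == logistic(R_A - R_B, k=ln(10)/400)
elo_e = logistic(1900 - 1500, k=math.log(10) / 400.0)  # ~0.909
```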
<blockquote>
<p>An interesting project for the future would be to evaluate different logistic losses while playing around with values of \(L\), \(k\), and \(x_0\).</p>
</blockquote>
<p>This so-called <strong>Elo Loss</strong> can be formulated as follows:</p>
<p><img src="/images/eloloss.jpg" width="100%" /></p>
<p>where \(\hat{y}\) is the prediction, \(y\) is the label, and \(m\) is the batch/minibatch size. The terms \(\hat{y}, y \in \mathbb{R}^d\) can either be tensors or scalars. The loss extends to the batch/minibatch setting by performing this computation over all samples in question.</p>
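The exact formula lives in the figure above, so the snippet below is only my reading of it: the Elo logistic applied per-sample to the gap between prediction and label, averaged over the batch. The 400-point scale and the sign of the gap are my assumptions, not the post’s.

```python
def elo_loss(preds, labels, scale=400.0):
    """Hypothetical Elo-style loss (one plausible reading of the figure):
    mean of the Elo logistic over the per-sample (y_hat - y) gap.
    The scale constant is an assumption carried over from chess Elo."""
    total = 0.0
    for y_hat, y in zip(preds, labels):
        total += 1.0 / (1.0 + 10 ** ((y_hat - y) / scale))
    return total / len(preds)

l_perfect = elo_loss([1.0, 0.0], [1.0, 0.0])  # 0.5 under this reading
l_under = elo_loss([0.0], [1.0])              # > 0.5 when undershooting
```

Note that under this reading a perfect prediction yields \(0.5\) rather than \(0\), so the actual formula may include a shift or rescaling on top of the logistic.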
<h2 id="results-and-discussion">Results and Discussion</h2>
<p>I trained a standard MLP using Elo Loss and standard Cross Entropy Loss on MNIST and CIFAR-10. I repeated the experiments with a CNN using a similar set-up but with the additional Fashion MNIST dataset included. I followed a liberal training regime. For some NLP tasks, I looked at sentiment analysis and trained LSTMs using Cross Entropy and Elo Loss.</p>
<p>I had a conversation with Lucas Beyer from Google AI a while back (<a href="">thread</a>). When comparing models across \(N\) runs, do not average the accuracies or seed the models. While the intention to maintain fairness is valid, there are better ways of going about it. For that very purpose, I’ve included multiple plots on the same graph to show some “in the wild” performance. Other than loss, I’ve kept the model architecture, optimiser, and miscellaneous hyper-parameters (learning rate, momentum, etc.) constant for all runs. I’ve been told this is a better way of reporting “fair” results.</p>
<h3 id="mlp-on-mnist">MLP on MNIST</h3>
<p><img src="/images/mnist_elo.jpg" width="100%" /></p>
<h3 id="cnn-on-cifar-10">CNN on CIFAR-10</h3>
<h3 id="cnn-on-fashion-mnist">CNN on Fashion MNIST</h3>
<h3 id="lstm-on-sentiment-analysis">LSTM on Sentiment Analysis</h3>
<h3 id="actor-critic-on-atari">Actor-Critic on ATARI</h3>
<p>Here, I picked 3 games: <code class="language-plaintext highlighter-rouge">Breakout-v0</code>, <code class="language-plaintext highlighter-rouge">Pong-v0</code>, and <code class="language-plaintext highlighter-rouge">CartPole-v1</code>.</p>
<h2 id="final-remarks">Final Remarks</h2>
<p>Here, I use the Elo rating system as a loss function. I’m not entirely sure if this has been done before based on a quick search online. I’m guessing any decent differentiable function can be used as a loss. This gets me thinking about the desirable characteristics of loss functions in general, aside from differentiability.</p>
<hr />
<h2>Mediums of Computational Communication (2022-01-24)</h2>
<h2 id="preface">Preface</h2>
<p>I’m currently taking a class on <a href="https://nusmods.com/modules/UIT2201/computer-science-the-i-t-revolution">computers and information technology</a> and recently covered machine languages and their evolution. An ongoing discussion at the time was the various ways machine languages differ from human languages. A lot of points were brought up about the flexibility of human languages (HLs), the rigidity of machine languages (MLs), the use of synonyms, references, metaphors, etc. in HLs, and others.</p>
<p>In this post, I wish to highlight three more differences I think are very cool and do a fine job at differentiating the two. They are, to some extent, inspired by the book I’m reading titled “Thinking Forth” by <em>Leo Brodie</em>. FYI, <code class="language-plaintext highlighter-rouge">Forth</code> is a popular programming language created by <em>Charles H. Moore</em> and is claimed to have been designed with a lot of forethought and reflection on what languages should be like. MLs are what the computer understands or is provided with; this includes low-level languages like <code class="language-plaintext highlighter-rouge">Assembly</code> and high-level languages like <code class="language-plaintext highlighter-rouge">Python</code> and <code class="language-plaintext highlighter-rouge">Golang</code>. HLs include the language this blogpost was written in (English!) and other languages spoken around the world.</p>
<h2 id="the-differences">The Differences</h2>
<h3 id="order-of-formalisation"><strong>Order of Formalisation</strong></h3>
<p>For HLs, the language is created and used first, then the grammar and syntax are invented. A bunch of people (i.e., linguists) come together and formalise a bunch of “rules” that help bucket/quantify the different unit elements of language (eg: words, phrases, the parts of speech, etc.). For MLs, the grammar, semantics, and syntax are devised first, and then the language is built on top of this schematic. Machine language is very derivative in that sense.</p>
<p><img src="/images/order.jpg" width="100%" /></p>
<h3 id="presence-of-design-principles"><strong>Presence of Design Principles</strong></h3>
<p>MLs have design principles that guide their construction:</p>
<p><img src="/images/desprinci.jpg" width="100%" /></p>
<p>However, HLs do not have such things attached to their creation. No one sat down one day and said “I’m going to create a schematic for a human language that’s so good, everyone’s gonna use it!!!”. A language develops naturally the longer a diverse set of people use it. Furthermore, back in the day, languages evolved greatly when they were taken out of their location of origin and spread far and wide. Locals would add their own twists to them. This isn’t widely seen in MLs – a language stays the same no matter where in the world it’s used. I guess it automatically solves the issue of “but it works on my computer!”. This goes back to MLs’ rigidity when it comes to their expression and usage.</p>
<h3 id="generalisability-to-higher-levels"><strong>Generalisability to Higher-levels</strong></h3>
<p>In the book, Brodie talks about how <code class="language-plaintext highlighter-rouge">Forth</code> was built using the design principles listed above. Conveniently, these design principles generalise to the construction of computer software at large, i.e., software can be modular, readable, writeable, manageable, or ripe with abstraction. It’s the concept of the “creation” (the software built) being constructed using similar principles or considerations as the “medium of creation” (the ML or programming language used).</p>
<p><img src="/images/pl.jpg" width="100%" /></p>
<p>This doesn’t necessarily work with HLs – the “design principles” of human language cannot be generalised to something higher, broader, or more abstract. One can’t exactly categorise or define these design principles vividly either. As in, what makes one human language “better” than the others? That question is easily answerable when it comes to machine languages (eg: Python &gt; Assembly because it’s more readable and writeable). Also, what constitutes the “creations”? Poems, stories, literature in general? In that case, while a poem may sound comprehensive, the same can’t be said about the language it is written in.</p>
<h2 id="in-a-nutshell">In a nutshell</h2>
<p>Languages are fascinating: they have evolved from being simple media of communication to running the world of politics, economics, stock markets, and more. Words are charged with meaning and should never be taken at face value, <em>especially</em> not online. Machine Languages have changed over the decades because humans crave less-rigid ways of sharing ideas and concepts.</p>
<blockquote>
<p>“Programming is a medium of communication – not between man and machine, but amongst man himself.” ~ my freshman “Intro to Programming” Professor at NUS</p>
</blockquote>
<p>I look forward to how communication with computers changes over the next few years. The improvement in SOTA NLP models suggests the possible usage of natural language to instruct computers on what to do. However, I feel it’s more worthwhile to teach computers to fully understand the nuance of what these commands mean instead of simply translating between natural language and machine code/language (which brings the whole <a href="https://plato.stanford.edu/entries/chinese-room/">Chinese Room Argument</a> into the picture).</p>
<p><em>Thank you to my classmate, Rachelle, for her 5-minute crash course on the history and evolution of language over text!!</em></p>
<hr />
<h2>Read this! (2022-01-16)</h2>
<p>Hello!</p>
<p>I’ll be gradually updating this website with more of my past articles. I’m in the process of moving from Substack to this personal site instead.</p>
<p>In the meantime, you can check out my <a href="http://rishtech.substack.com">Substack blog here</a>.</p>
<p>Cheers!</p>