Understanding graph neural networks by way of convolutional nets
Augusto Stoffel (PhD)
In this article, we will introduce the basic ideas behind graph neural networks (GNNs) through an analogy with convolutional neural networks (CNNs), which are very well known due to their prevalence in the field of computer vision. In fact, we'll see that convolutional nets are an example of GNNs, albeit one where the underlying graph is very simple, perhaps even boring. Once we see how to think of a convolutional net through this lens, it won't be hard to replace that boring graph with more interesting ones, and we'll arrive naturally at the general concept of a GNN.
After that, we will survey some applications of GNNs, including our use here at dida.
But let's start with the basics.
Graphs
A graph is simply a set of so-called vertices or nodes, which we typically represent as dots on the page, together with a set of pairs of vertices, called edges or links, which we depict as arrows connecting the vertices. (The precise definitions and terminology vary quite a bit, but this will not cause confusion here.) It's also very common to attach labels to the vertices and edges.
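To make the definition concrete, here is one minimal way to store a directed graph with node attributes and edge labels in Python. The class name `LabeledGraph` and its methods are our own illustrative choices, not a standard API; real graph libraries use more efficient representations such as adjacency lists or sparse matrices.

```python
from dataclasses import dataclass, field


@dataclass
class LabeledGraph:
    """A directed graph with attributes on nodes and labels on edges (illustrative)."""
    node_attrs: dict = field(default_factory=dict)  # node -> attributes
    edges: list = field(default_factory=list)       # (source, target, label) triples

    def add_node(self, node, attrs=None):
        self.node_attrs[node] = attrs

    def add_edge(self, source, target, label=None):
        # An edge is a pair of existing vertices, plus an optional label.
        if source not in self.node_attrs or target not in self.node_attrs:
            raise KeyError("both endpoints must be existing nodes")
        self.edges.append((source, target, label))

    def in_edges(self, node):
        """(source, label) for every edge arriving at `node`."""
        return [(s, lbl) for s, t, lbl in self.edges if t == node]


# A tiny social network; friendship is symmetric, so it gets an edge each way.
g = LabeledGraph()
for person, age in [("ann", 34), ("bob", 29), ("cy", 41)]:
    g.add_node(person, {"age": age})
g.add_edge("ann", "bob", "friend")
g.add_edge("bob", "ann", "friend")
g.add_edge("cy", "ann", "colleague")

print(g.in_edges("ann"))  # [('bob', 'friend'), ('cy', 'colleague')]
```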
Graphs are a fundamental concept in computer science and combinatorics, but they are also widely used in practical applications to represent any kind of network of entities, such as
social networks, where nodes represent people and edges can represent a friendship relationship;
webpages and hyperlinks between them, or academic papers and citations;
constituent steps of a complex process and dependency relationships between them.
A nice introductory book about graphs and network science is “Linked: The New Science of Networks” by Albert-László Barabási.
Graph theory meets machine learning
Graph neural networks are a category of machine learning models which are tailor-made to learn from data structured in graph form.
While it is often possible to massage a graph into a form that can serve as input to other ML methods, GNNs work directly with graph data and therefore have greater potential to learn fine details of the network's topology.
Analogy with convolutional nets
To understand how GNNs work, and be able to appreciate some of the finer details of the story, let's make an analogy with the popular convolutional networks, which are ubiquitous in computer vision tasks. A single convolutional layer is simply a procedure that computes, for each pixel $$p$$ of an image, an output, or label, $$\mathrm{out}(p)$$ based on inputs, or attributes, $$\mathrm{in}(p)$$ of the pixel in question and of its immediate neighbors.
Now, let's look at this single convolutional layer in a somewhat strange way. Let's think of a small 3 × 3 grid in our image as a little graph: each pixel is represented by a vertex, and the fact that the central pixel has 8 immediate neighbors is represented by edges connecting those neighbors to it.
The entire image would then be represented by a huge graph where each node has 8 incoming edges and 8 outgoing edges, each of them labeled by one of the 8 compass directions (N, NE, E, SE, S, SW, W, NW). (We will not worry about pixels near the boundary of the image, for which the situation is slightly different, of course.)
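The grid graph just described can be built explicitly. In this sketch (the function name `grid_graph` is hypothetical), an edge `(q, p, D)` means that pixel `q` lies in direction `D` as seen from pixel `p`, with rows growing downward. Neighbors that would fall outside the image are simply dropped, which is one possible, assumed, treatment of the boundary.

```python
# Offset (row, col) of the neighbor lying in each compass direction; rows grow downward.
DIRECTIONS = {
    "N": (-1, 0), "NE": (-1, 1), "E": (0, 1), "SE": (1, 1),
    "S": (1, 0), "SW": (1, -1), "W": (0, -1), "NW": (-1, -1),
}


def grid_graph(height, width):
    """Edges (q, p, D) of the pixel graph: pixel q lies in direction D as seen from p."""
    edges = []
    for r in range(height):
        for c in range(width):
            for label, (dr, dc) in DIRECTIONS.items():
                qr, qc = r + dr, c + dc
                # Skip neighbors that would fall outside the image.
                if 0 <= qr < height and 0 <= qc < width:
                    edges.append(((qr, qc), (r, c), label))
    return edges


edges = grid_graph(4, 5)
incoming = [e for e in edges if e[1] == (1, 1)]   # an interior pixel
outgoing = [e for e in edges if e[0] == (1, 1)]
print(len(incoming), len(outgoing))               # 8 8
print(len([e for e in edges if e[1] == (0, 0)]))  # 3 (a corner pixel)
```

Interior pixels get exactly 8 incoming and 8 outgoing edges, one per direction; only boundary pixels have fewer.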
With this in place, we can think of a convolutional layer as the following procedure:
Each node of our graph (representing a pixel) starts with some attributes. To fix ideas, let's assume this is the first layer in a deeper neural network and the attributes of the node are just the RGB values of the pixel.
Each node receives a “message” from the nodes connected to it (which in this case are just the neighboring pixels), for a total of 8 messages. The content of the message, again, is just the RGB value of the pixel represented by the origin of the edge in question — the pixel doesn't “know” any other information it could pass along, anyway.
For each node, a fixed formula is used to compute an output value in terms of the attributes of the node and the content of the messages it received.
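The three steps above can be sketched as a single function. This is an illustrative sketch rather than a library API: `message_passing_layer` and the `update` callback are hypothetical names, and the update rule in the example is only a placeholder for the fixed formula of the last step.

```python
from collections import defaultdict


def message_passing_layer(node_attrs, edges, update):
    """One round of message passing (illustrative sketch).

    node_attrs: dict node -> attribute (step 1: every node starts with attributes)
    edges: iterable of (source, target, label) triples
    update: function (own_attr, messages) -> new attribute, where messages
            is a list of (label, sender_attr) pairs
    """
    # Step 2: every edge delivers the sender's attribute to the target's inbox,
    # in an "envelope" carrying the edge label.
    inbox = defaultdict(list)
    for source, target, label in edges:
        inbox[target].append((label, node_attrs[source]))
    # Step 3: the same fixed formula is applied at every node.
    return {node: update(attr, inbox[node]) for node, attr in node_attrs.items()}


attrs = {"a": 1.0, "b": 2.0, "c": 3.0}
edges = [("b", "a", "x"), ("c", "a", "y"), ("a", "b", "x")]


def sum_update(own, messages):
    # Placeholder rule: own attribute plus the sum of the received attributes.
    return own + sum(value for _, value in messages)


print(message_passing_layer(attrs, edges, sum_update))  # {'a': 6.0, 'b': 3.0, 'c': 3.0}
```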
Let's clarify the role of edge labels in the above. You can think of them as the “envelope” of the message. When a pixel receives a message, it learns the RGB value of a neighbor in a specific direction. If we didn't label the edges with compass directions, the destination node would only learn that some neighbor has the given RGB value. As a result, the output of a given pixel could only be computed in terms of the (unordered) collection of RGB values of its neighbors, which, as we know, is not how CNNs work.
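A small experiment shows why the envelopes matter. The two 3 × 3 patches below are mirror images of each other (a bright column on the east side versus on the west side), using single grayscale values instead of RGB for brevity; the helper `neighbor_messages` is hypothetical. Both centers receive the same unordered collection of values, so a label-blind layer could not tell them apart, while the direction-labeled messages differ.

```python
OFFSETS = {"N": (-1, 0), "NE": (-1, 1), "E": (0, 1), "SE": (1, 1),
           "S": (1, 0), "SW": (1, -1), "W": (0, -1), "NW": (-1, -1)}


def neighbor_messages(patch):
    """Messages received by the center of a 3x3 patch, keyed by compass direction."""
    return {d: patch[1 + dr][1 + dc] for d, (dr, dc) in OFFSETS.items()}


bright_east = [[0, 0, 9],
               [0, 5, 9],
               [0, 0, 9]]
bright_west = [row[::-1] for row in bright_east]  # mirror image

a = neighbor_messages(bright_east)
b = neighbor_messages(bright_west)

# Without envelopes, the center only sees an unordered collection of values.
print(sorted(a.values()) == sorted(b.values()))  # True
# With direction labels, the two patches are told apart.
print(a == b)                                    # False
print(a["E"], b["E"])                            # 9 0
```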
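Written in this graph-flavored notation, a convolutional layer computes

$$\mathrm{out}(p) = \sigma\Big(w_O \, \mathrm{in}(p) + \sum_{q \xrightarrow{D} p} w_D \, \mathrm{in}(q)\Big)$$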
Here, $$p$$ and $$q$$ denote nodes, $$D \in \{N, NE, E, SE, S, SW, W, NW\}$$ denotes one of the 8 compass directions, $$w_D$$ is a weight of the convolutional layer (associated with the direction $$D$$), and the subscript of the summation sign indicates that the sum is over all nodes $$q$$ connected to $$p$$, with $$D$$ denoting the label of the edge connecting them. There's also an additional weight $$w_O$$ associated with the input value of $$p$$ itself. Finally, $$\sigma$$ denotes an activation function.
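As a sanity check, the sketch below implements this graph-style formula directly and compares it to an ordinary 3 × 3 convolution (strictly, cross-correlation, which is what CNN libraries compute) whose kernel holds $$w_D$$ at the offset of direction $$D$$ and $$w_O$$ in the center. Assumptions made for brevity: grayscale pixels with scalar weights (for RGB inputs each $$w_D$$ would be a matrix), ReLU as $$\sigma$$, and zero padding at the image boundary. The function names are hypothetical.

```python
import math

OFFSETS = {"N": (-1, 0), "NE": (-1, 1), "E": (0, 1), "SE": (1, 1),
           "S": (1, 0), "SW": (1, -1), "W": (0, -1), "NW": (-1, -1)}


def relu(x):
    return max(0.0, x)


def graph_conv_layer(image, w, w_self, sigma=relu):
    """out(p) = sigma(w_O * in(p) + sum over labeled edges q -D-> p of w_D * in(q))."""
    h, width = len(image), len(image[0])
    out = [[0.0] * width for _ in range(h)]
    for r in range(h):
        for c in range(width):
            total = w_self * image[r][c]
            for d, (dr, dc) in OFFSETS.items():
                qr, qc = r + dr, c + dc
                if 0 <= qr < h and 0 <= qc < width:  # the edge q -> p exists
                    total += w[d] * image[qr][qc]
            out[r][c] = sigma(total)
    return out


def conv3x3(image, kernel, sigma=relu):
    """Ordinary 3x3 cross-correlation with zero padding, followed by sigma."""
    h, width = len(image), len(image[0])
    out = [[0.0] * width for _ in range(h)]
    for r in range(h):
        for c in range(width):
            total = 0.0
            for i in range(3):
                for j in range(3):
                    rr, cc = r + i - 1, c + j - 1
                    if 0 <= rr < h and 0 <= cc < width:
                        total += kernel[i][j] * image[rr][cc]
            out[r][c] = sigma(total)
    return out


# Arbitrary weights w_D, plus w_O for the node itself.
w = {"N": 0.5, "NE": -1.0, "E": 2.0, "SE": 0.0,
     "S": 1.5, "SW": -0.5, "W": 1.0, "NW": 0.25}
w_self = 3.0

# The same weights arranged as a 3x3 kernel: position (1+dr, 1+dc) holds w_D.
kernel = [[0.0] * 3 for _ in range(3)]
kernel[1][1] = w_self
for d, (dr, dc) in OFFSETS.items():
    kernel[1 + dr][1 + dc] = w[d]

image = [[1, 2, 0, 4],
         [3, 1, 1, 0],
         [0, 2, 5, 1]]

graph_out = graph_conv_layer(image, w, w_self)
conv_out = conv3x3(image, kernel)
print(all(math.isclose(x, y, abs_tol=1e-9)
          for row_g, row_c in zip(graph_out, conv_out)
          for x, y in zip(row_g, row_c)))  # True
```

The two functions only differ in bookkeeping: one indexes weights by edge label, the other by kernel position.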
Graph neural networks
A graph neural network generalizes the idea of the previous section by simply removing several of the special features of the image example:
A node doesn't need to have a fixed number of incoming or outgoing edges.
There can be several edges of a fixed kind (or none at all) reaching a given node.
Some GNN architectures allow attaching attributes to edges as well as vertices and placing labels on the nodes.
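A formula for a basic GNN layer would be

$$\mathrm{out}(p) = \sigma\Big(\sum_{q \xrightarrow{D} p} \frac{1}{\sqrt{d_p \, d_q}} \, w_D \, \mathrm{in}(q)\Big)$$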
where, besides the notation used in the previous formula, we also included a normalization factor computed in terms of the degrees $$d_p$$, $$d_q$$ of the nodes, that is, the total number of incident edges. We didn't explicitly include the attributes $$\mathrm{in}(p)$$ of the node $$p$$ itself in the formula for $$\mathrm{out}(p)$$, but this can be done by adding special edges, say with label $$O$$, connecting each node to itself.
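Here is a minimal sketch of such a layer with scalar node attributes. The specific normalization $$1/\sqrt{d_p d_q}$$ is an assumption on our part (it is the symmetric normalization used by graph convolutional networks), as are the choices of counting $$d_p$$ as the number of edges arriving at $$p$$, self-loop included, and of ReLU as $$\sigma$$. The function name `gnn_layer` is hypothetical. The example graph is a star, so nodes now have different numbers of incoming edges.

```python
import math
from collections import defaultdict


def relu(x):
    return max(0.0, x)


def gnn_layer(node_attrs, edges, w, sigma=relu):
    """One basic GNN layer with scalar node attributes (illustrative sketch).

    Computes out(p) = sigma(sum over edges q -D-> p of w_D * in(q) / sqrt(d_p * d_q)).
    node_attrs: dict node -> scalar attribute in(p)
    edges: list of (q, p, D) triples, i.e. an edge from q to p with label D
    w: dict label D -> scalar weight w_D (must include the self-loop label "O")
    """
    # Self-loops labeled "O" let in(p) enter out(p) and keep every degree >= 1.
    edges = list(edges) + [(p, p, "O") for p in node_attrs]
    # Assumption: d_p counts the edges arriving at p, self-loop included.
    degree = defaultdict(int)
    for _, p, _ in edges:
        degree[p] += 1
    total = {p: 0.0 for p in node_attrs}
    for q, p, label in edges:
        total[p] += w[label] * node_attrs[q] / math.sqrt(degree[p] * degree[q])
    return {p: sigma(t) for p, t in total.items()}


# A star: hub "h" linked in both directions to three leaves.
attrs = {"h": 1.0, "a": 2.0, "b": 2.0, "c": 2.0}
edges = []
for leaf in ["a", "b", "c"]:
    edges += [(leaf, "h", "link"), ("h", leaf, "link")]
w = {"link": 1.0, "O": 1.0}

out = gnn_layer(attrs, edges, w)
print({p: round(v, 4) for p, v in out.items()})
# {'h': 2.3713, 'a': 1.3536, 'b': 1.3536, 'c': 1.3536}
```

Without the normalization, a node's output would grow with the number of neighbors it happens to have; dividing by the degrees keeps the scale comparable across nodes.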
This, of course, is just a first stab at the definition of a graph layer. Many variations and refinements are possible, and there also exist graph versions of recurrent neural networks, gated recurrent units, and self-attention layers. Note also that we have only discussed a single GNN layer here. As with convolutional nets, and deep learning in general, the magic starts to happen when we stack several of those layers.
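One way to see why stacking helps: after $$k$$ layers, a node's output can depend on inputs up to $$k$$ edges away, just as stacked 3 × 3 convolutions see an ever larger patch of the image. The sketch below (the helper `receptive_fields` is hypothetical) tracks only which nodes can influence which, not actual values, and assumes each layer also uses the node's own attributes.

```python
def receptive_fields(nodes, edges, num_layers):
    """For each node, the set of input nodes that can influence its output."""
    # Before any layer, each node only "knows" its own attributes.
    reach = {p: {p} for p in nodes}
    for _ in range(num_layers):
        new = {p: set(reach[p]) for p in nodes}  # keep what p already knew
        for q, p in edges:                       # a message from q to p
            new[p] |= reach[q]
        reach = new
    return reach


# A path graph 0 - 1 - 2 - 3 - 4, with edges in both directions.
nodes = range(5)
edges = [(i, i + 1) for i in range(4)] + [(i + 1, i) for i in range(4)]
for k in range(4):
    print(k, sorted(receptive_fields(nodes, edges, k)[0]))
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
```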
Ready-to-use implementations of various GNN layers can be found in libraries such as PyTorch Geometric, DGL, and Spektral.
Applications
Classes of problems that can be addressed with GNNs include:
Node label prediction: The goal here is to make a prediction about each node of a graph. For instance, we may wish to decide which nodes of a transportation network are the main hubs.
Graph classification: The goal here is to attach a label to an entire graph. For instance, given the graph of a molecule, representing its constituent atoms and the bonds between them, we might want to predict molecular properties. Another example: classifying industrial components into categories based on information about their subcomponents and the way those are assembled.
Link prediction: This refers to finding potential or missed edges in a graph. Examples include friend recommendation in social networks, protein-protein interaction prediction, and knowledge graph completion.
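These tasks typically differ mainly in how the final node representations produced by a stack of GNN layers are used. The sketch below shows common, but by no means unique, choices, assumed here for illustration: a linear score per node, mean pooling over all nodes as a graph-level readout, and a sigmoid of a dot product to score a candidate edge. The function names are hypothetical, and the embeddings and weights are made-up numbers standing in for the output of a trained model.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def node_scores(embeddings, weights):
    """Node-level task: one score per node from its final embedding."""
    return {p: dot(h, weights) for p, h in embeddings.items()}


def graph_readout(embeddings):
    """Graph-level task: pool all node embeddings into one vector (mean pooling)."""
    vecs = list(embeddings.values())
    return [sum(col) / len(vecs) for col in zip(*vecs)]


def link_probability(embeddings, p, q):
    """Link prediction: score a candidate edge p - q by a sigmoid of a dot product."""
    return 1.0 / (1.0 + math.exp(-dot(embeddings[p], embeddings[q])))


# Pretend these came out of a stack of GNN layers.
emb = {"a": [1.0, 0.0], "b": [0.5, 0.5], "c": [-1.0, 0.5]}

print(node_scores(emb, [1.0, -1.0]))  # {'a': 1.0, 'b': 0.0, 'c': -1.5}
print(graph_readout(emb))             # one vector for the whole graph
print(link_probability(emb, "a", "b") > link_probability(emb, "a", "c"))  # True
```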
At dida, we have successfully employed graph neural networks in NLP tasks where the physical layout of words on the page is relevant, such as table recognition and understanding. Here, a page of a document is described by a graph whose nodes are words, lines or columns of text, etc. Node attributes can include both information about their textual content and information about their position on the page. Finally, edges of the graph represent various relationships between these page elements, such as two words being neighbors or a word belonging to a row or column.
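To illustrate what such a page graph might look like, here is a hypothetical construction from OCR output. It is not dida's pipeline, which is not described here: the word tuples, the helper `page_graph`, and the rule (horizontal neighbors on the same line, vertical neighbors on consecutive lines whose horizontal extents overlap) are assumptions made for this sketch, and it covers only the neighbor relation, not row or column membership. As in the pixel graph, an edge `(q, p, D)` means that word `q` lies in direction `D` as seen from word `p`.

```python
from collections import defaultdict

# Hypothetical OCR output: (text, x_left, x_right, line_index).
words = [
    ("Item", 0, 40, 0), ("Price", 100, 140, 0),
    ("Apple", 0, 45, 1), ("1.20", 100, 130, 1),
]


def page_graph(words):
    """Nodes: word indices with text/position attributes. Edges: (q, p, label)."""
    nodes = {i: {"text": t, "x": (x0 + x1) / 2, "line": ln}
             for i, (t, x0, x1, ln) in enumerate(words)}
    by_line = defaultdict(list)
    for i, (_, x0, _, ln) in enumerate(words):
        by_line[ln].append((x0, i))
    edges = []
    # Horizontal neighbors: consecutive words on the same line of text.
    for _, items in by_line.items():
        order = [i for _, i in sorted(items)]
        for left, right in zip(order, order[1:]):
            edges.append((right, left, "E"))  # `right` lies east of `left`
            edges.append((left, right, "W"))  # `left` lies west of `right`
    # Vertical neighbors: words on consecutive lines whose x-ranges overlap.
    for i, (_, a0, a1, la) in enumerate(words):
        for j, (_, b0, b1, lb) in enumerate(words):
            if lb == la + 1 and a0 < b1 and b0 < a1:
                edges.append((j, i, "S"))  # word j lies south of word i
                edges.append((i, j, "N"))  # word i lies north of word j
    return nodes, edges


nodes, edges = page_graph(words)
# Messages arriving at the cell "1.20": its row label and its column header.
print([(nodes[q]["text"], label) for q, p, label in edges if p == 3])
# [('Apple', 'W'), ('Price', 'N')]
```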
In comparison with more traditional ML methods, we have found deep GNNs to deliver excellent results in table understanding tasks, without the need for extensive feature engineering or ad-hoc table segmentation algorithms. This is especially important when dealing with heterogeneous collections of documents with lots of variability in their formatting.