Posted by Dustin Zelle, Software Engineer, Google Research, and Arno Eigenwillig, Software Engineer, CoreML
Objects and their relationships are everywhere in our world, and understanding an object often requires considering its relationships with other objects. This applies to transportation networks, production networks, knowledge graphs, and social networks, among others. Graphs, which consist of nodes connected by edges in various ways, have long been used in discrete mathematics and computer science to formalize such relationships. However, most machine learning algorithms are limited to regular and uniform relations between input objects, such as grids of pixels or sequences of words.
Graph neural networks (GNNs) have emerged as a powerful technique for leveraging the connectivity of graphs and the features of nodes and edges. GNNs can make predictions for entire graphs, individual nodes, or potential edges. They are also useful for encoding a graph’s relational information in a continuous way that can be incorporated into other deep learning systems.
We are excited to announce the release of TensorFlow GNN 1.0 (TF-GNN), a production-tested library for building GNNs at large scales. TF-GNN supports both modeling and training in TensorFlow, as well as the extraction of input graphs from large data stores. It supports heterogeneous graphs out of the box, in which different types of objects and relations are represented by distinct sets of nodes and edges. Within TensorFlow, these graphs are represented by objects of type tfgnn.GraphTensor, a composite tensor type that stores both the graph structure and its features.
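As a minimal sketch of what this looks like in code, consider a toy heterogeneous graph with "author" and "paper" node sets joined by "writes" edges; the node set names, feature name, and values are invented for this illustration, not part of the library:

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

# A toy heterogeneous graph: two "author" nodes, three "paper" nodes,
# and three author->paper "writes" edges. Names and feature values
# are illustrative only.
graph = tfgnn.GraphTensor.from_pieces(
    node_sets={
        "author": tfgnn.NodeSet.from_fields(
            sizes=tf.constant([2]),
            features={"hidden_state": tf.random.normal([2, 16])}),
        "paper": tfgnn.NodeSet.from_fields(
            sizes=tf.constant([3]),
            features={"hidden_state": tf.random.normal([3, 16])}),
    },
    edge_sets={
        "writes": tfgnn.EdgeSet.from_fields(
            sizes=tf.constant([3]),
            adjacency=tfgnn.Adjacency.from_indices(
                source=("author", tf.constant([0, 0, 1])),
                target=("paper", tf.constant([0, 1, 2])))),
    })
```

Because each node set and edge set carries its own features, authors and papers can have embeddings with different meanings, and even different shapes, within the same graph.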
TF-GNN provides a flexible Python API for configuring dynamic or batch subgraph sampling, which is crucial for GNN training. It supports sampling interactively in a Colab notebook, efficient in-memory sampling of small datasets on a single training host, and distributed sampling with Apache Beam for large datasets stored on a network filesystem.
During training, a GNN computes a hidden state at the root node of each sampled subgraph by aggregating and encoding the relevant information from the root's neighborhood. This is typically done with message-passing neural networks: in each round, nodes receive messages from their neighbors along edges and update their own hidden state from them. After enough rounds, the root node's hidden state reflects the aggregate information from its whole neighborhood and can be used to make predictions.
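As an illustration, here is one such round written with TF-GNN's low-level broadcast and pool operations, reusing the toy "writes" graph and the "hidden_state" feature from the sketch above (the layer size is an arbitrary choice for the example):

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

combine = tf.keras.layers.Dense(16, activation="relu")

def message_passing_round(graph: tfgnn.GraphTensor) -> tf.Tensor:
  """Returns new hidden states for the "paper" nodes."""
  # Each "writes" edge reads the hidden state of its source (author) node.
  messages = tfgnn.broadcast_node_to_edges(
      graph, "writes", tfgnn.SOURCE, feature_name="hidden_state")
  # Each paper node sums up the messages arriving on its incoming edges.
  pooled = tfgnn.pool_edges_to_node(
      graph, "writes", tfgnn.TARGET, reduce_type="sum",
      feature_value=messages)
  # The new state encodes the old state together with the pooled messages.
  old_state = graph.node_sets["paper"]["hidden_state"]
  return combine(tf.concat([old_state, pooled], axis=-1))
```

TF-GNN's predefined Keras layers package this pattern, so most users do not have to write it by hand.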
In addition to supervised training, GNNs can also be trained in an unsupervised manner to compute continuous representations of the graph structure. TF-GNN supports a fine-grained specification of unsupervised objectives for heterogeneous graphs.
TF-GNN provides various levels of abstraction for building and training GNNs. Users can choose from predefined models bundled with the library or compose their own models from primitives for passing data around the graph. The library also includes a training orchestration tool, the TF-GNN Runner, which simplifies the training process and provides ready-made support for distributed training and for padding GraphTensors to the fixed shapes required when training on Cloud TPUs.
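As a sketch of how these pieces fit together, the following combines a model_fn built from the library's Keras layers (GraphUpdate, NodeSetUpdate, SimpleConv) with a training run orchestrated by the Runner. The file patterns, node set and label feature names, dimensions, and epoch count are placeholders for illustration, not values prescribed by the library:

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner

# The graph schema (and hence the GraphTensorSpec) is read from a file;
# the path is a placeholder.
graph_schema = tfgnn.read_schema("/tmp/graph_schema.pbtxt")
gtspec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)

def model_fn(graph_tensor_spec):
  """Builds a Keras model that runs two rounds of message passing.

  Assumes the input GraphTensors already carry a "hidden_state" feature
  on each node set (e.g., initialized with tfgnn.keras.layers.MapFeatures).
  """
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
  for _ in range(2):
    graph = tfgnn.keras.layers.GraphUpdate(node_sets={
        "paper": tfgnn.keras.layers.NodeSetUpdate(
            # Pool messages from authors into papers along "writes" edges.
            {"writes": tfgnn.keras.layers.SimpleConv(
                message_fn=tf.keras.layers.Dense(16, "relu"),
                reduce_type="sum",
                receiver_tag=tfgnn.TARGET)},
            # Compute the next state from old state plus pooled messages.
            tfgnn.keras.layers.NextStateFromConcat(
                tf.keras.layers.Dense(16)))})(graph)
  return tf.keras.Model(inputs, graph)

runner.run(
    train_ds_provider=runner.TFRecordDatasetProvider("/tmp/train*"),
    valid_ds_provider=runner.TFRecordDatasetProvider("/tmp/valid*"),
    model_fn=model_fn,
    optimizer_fn=tf.keras.optimizers.Adam,
    # Binary classification at the root "paper" node; the label feature
    # name is an assumption about how the sampled data is stored.
    task=runner.RootNodeBinaryClassification(
        "paper", label_feature_name="label"),
    trainer=runner.KerasTrainer(
        tf.distribute.MirroredStrategy(), model_dir="/tmp/model"),
    gtspec=gtspec,
    epochs=10,
    global_batch_size=128,
)
```

Swapping the distribution strategy or the dataset provider changes where and how training runs without touching the model code, which is the point of the Runner's orchestration layer.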
Overall, TF-GNN aims to advance the application of GNNs in TensorFlow and make it easier for researchers and developers to build and train GNN models at large scales.