Posted by Carlos Esteves and Ameesh Makadia, Research Scientists, Google Research, Athena Team
Typical deep learning models for computer vision, such as convolutional neural networks (CNNs) and vision transformers (ViTs), are designed to process signals in planar (flat) spaces. However, in scientific applications, we often encounter data that is represented on a sphere, such as variables sampled from the Earth’s atmosphere or cosmological data. Processing spherical signals with traditional methods designed for planar images presents two challenges. The first is sampling: it is difficult to define uniform grids on a sphere without distortion. The second is rotation: signals and patterns on a sphere can appear in any orientation, so models need to be robust to 3D rotations, ideally by being rotation equivariant.
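To make the sampling problem concrete, here is a minimal sketch in plain Python that computes the area of the cells in an equiangular (latitude-longitude) grid on the unit sphere. The function name `equiangular_cell_areas` and the grid size are assumptions chosen for illustration; they are not taken from the paper or the library.

```python
import math

def equiangular_cell_areas(n_lat, n_lon):
    """Return the area of one grid cell in each latitude band (unit sphere)."""
    d_lon = 2 * math.pi / n_lon
    areas = []
    for i in range(n_lat):
        theta0 = math.pi * i / n_lat        # colatitude of the cell's upper edge
        theta1 = math.pi * (i + 1) / n_lat  # colatitude of the cell's lower edge
        # Exact integral of sin(theta) dtheta dphi over the cell.
        areas.append(d_lon * (math.cos(theta0) - math.cos(theta1)))
    return areas

# Hypothetical small grid: 8 latitude bands, 16 longitude steps.
n_lat, n_lon = 8, 16
areas = equiangular_cell_areas(n_lat, n_lon)

total_area = n_lon * sum(areas)          # should equal the sphere's area, 4*pi
ratio = areas[n_lat // 2] / areas[0]     # near-equator cell vs. polar cell
print(f"total area = {total_area:.6f}, equator/pole area ratio = {ratio:.2f}")
```

Cells next to the poles cover far less area than cells near the equator, so a planar filter sliding over this grid sees the same physical pattern at very different scales depending on latitude.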
To address these challenges, our paper “Scaling Spherical CNNs”, presented at ICML 2023, introduces an open-source library in JAX for deep learning on spherical surfaces. The library processes spherical signals efficiently, and models built with it achieve state-of-the-art performance on weather forecasting and molecular property prediction benchmarks.
Spherical CNNs address both the sampling problem and the need for rotation robustness by leveraging spherical convolution and cross-correlation operations, computed via generalized Fourier transforms. On planar surfaces, convolution with small filters is fast; on the sphere, the higher computational cost of these operations has so far limited the application of spherical CNNs to small models and low-resolution datasets. In our work, we have implemented spherical convolutions in JAX with a focus on speed, enabling distributed training over a large number of TPUs using data parallelism. We have also introduced new activation and normalization layers, as well as a new residual block, to improve accuracy and efficiency.
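To illustrate why computing the convolution in the spectral domain gives rotation equivariance, here is a minimal sketch in plain Python. It is not the library's API: the names `zonal_conv` and `rotate_z` and the coefficient values are hypothetical. It assumes the classical result that convolving with a zonal (axially symmetric) filter scales each spherical harmonic coefficient of degree l by a factor that depends only on l, and that a rotation about the z-axis multiplies the coefficient of order m by a phase (the sign convention used here is one common choice). The operations in the paper may be more general.

```python
import cmath
import math

def zonal_conv(f_coeffs, h_zonal):
    """Spectral spherical convolution with a zonal filter.

    f_coeffs: dict mapping (l, m) -> complex spherical harmonic coefficient.
    h_zonal:  list where h_zonal[l] is the filter's (l, 0) coefficient.
    """
    return {
        (l, m): 2 * math.pi * math.sqrt(4 * math.pi / (2 * l + 1)) * c * h_zonal[l]
        for (l, m), c in f_coeffs.items()
    }

def rotate_z(f_coeffs, alpha):
    """Rotate the signal by alpha about the z-axis: f_lm -> exp(-i m alpha) f_lm."""
    return {(l, m): cmath.exp(-1j * m * alpha) * c for (l, m), c in f_coeffs.items()}

# Hypothetical band-limited signal (degrees l = 0 and l = 1) and zonal filter.
f = {(0, 0): 1.0 + 0j, (1, -1): 0.5 - 0.2j, (1, 0): -0.3 + 0j, (1, 1): 0.1 + 0.4j}
h = [0.7, -0.25]
alpha = 0.9

conv_then_rotate = rotate_z(zonal_conv(f, h), alpha)
rotate_then_conv = zonal_conv(rotate_z(f, alpha), h)

# Equivariance: both orders agree up to floating-point error.
max_diff = max(abs(conv_then_rotate[k] - rotate_then_conv[k]) for k in f)
print(f"max difference = {max_diff:.2e}")
```

The sketch only checks rotations about one axis. A general 3D rotation mixes the orders m within each degree l, and a per-degree scaling still commutes with that mixing, which is the intuition behind the equivariance of the spectral approach.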
We have applied our models to molecular property regression and weather forecasting tasks. In molecular property regression, we map molecules to a set of spherical functions using physics-based interactions between atoms. Our spherical CNNs achieve state-of-the-art performance in predicting molecular properties. In weather forecasting, our models handle atmospheric data natively represented on the sphere and effectively capture repeating patterns. They outperform or match neural weather models based on conventional CNNs.
We have made our library for efficient spherical CNNs available in JAX, and we believe it will be valuable in various scientific applications, as well as computer vision and 3D vision. Weather forecasting is an active area of research at Google, and we are continuously working on building more accurate and robust models. We also aim to provide tools that enable further advancements in the research community, such as the recently released WeatherBench 2 dataset.
We would like to acknowledge our collaborators and colleagues for their contributions and support in this project.