
Facebook Open-Sources GHN-2 AI for Fast Initialization of Deep-Learning Models


A team from Facebook AI Research (FAIR) and the University of Guelph has open-sourced an improved Graph HyperNetworks (GHN-2) meta-model that predicts initial parameters for deep-learning neural networks. GHN-2 executes in less than a second on a CPU and predicts parameter values for computer vision (CV) networks that achieve up to 77% top-1 accuracy on CIFAR-10 with no additional training.

The researchers described the system and a series of experiments in a paper accepted for the upcoming Conference on Neural Information Processing Systems (NeurIPS). To tackle the problem of predicting initial parameters for deep-learning models, the team generated a dataset called DeepNets-1M, which contains one million examples of neural network architectures represented as computational graphs. They then used meta-learning to train a modified graph hyper-network (GHN) on this dataset; the trained meta-model can then predict parameters for a network architecture it has not seen before. The resulting meta-model is "surprisingly good" at the task, even for architectures much larger than the ones used in training. When used to initialize a 24M-parameter ResNet-50, the meta-model predicted parameters that achieved 60% accuracy on CIFAR-10 with no gradient updates. Along with the trained meta-model and code, the team released the DeepNets-1M training dataset as well as several benchmark test datasets. According to lead author Boris Knyazev,

Based on our...paper, we are one step closer to replacing hand-designed optimizers with a single meta-model. Our meta-model can predict parameters for almost any neural network in just one forward pass.
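To make "architectures represented as computational graphs" concrete, the sketch below encodes a tiny CNN with one skip connection as a directed graph whose nodes are operations and whose edges carry data. It is only an illustration under stated assumptions: the class names, operation labels and fields are hypothetical and do not reflect the actual DeepNets-1M file format, and biases, normalization and pooling are omitted for brevity.

```python
import math
from dataclasses import dataclass, field
from graphlib import TopologicalSorter  # Python 3.9+


@dataclass
class Node:
    name: str
    op: str              # hypothetical operation label, e.g. "conv3x3", "relu", "add"
    shape: tuple = ()    # shape of the parameter tensor this op owns; () if none


@dataclass
class ArchGraph:
    nodes: dict = field(default_factory=dict)  # name -> Node
    preds: dict = field(default_factory=dict)  # name -> set of input node names

    def add(self, node, inputs=()):
        self.nodes[node.name] = node
        self.preds[node.name] = set(inputs)

    def forward_order(self):
        # A topological order: every node appears after all of its inputs.
        return list(TopologicalSorter(self.preds).static_order())

    def param_count(self):
        # Only nodes that own a parameter tensor contribute (biases omitted).
        return sum(math.prod(n.shape) for n in self.nodes.values() if n.shape)


# Hypothetical tiny CNN with a skip connection (pooling and biases omitted).
block = ArchGraph()
block.add(Node("input", "input"))
block.add(Node("conv1", "conv3x3", (16, 3, 3, 3)), ["input"])
block.add(Node("relu1", "relu"), ["conv1"])
block.add(Node("conv2", "conv3x3", (16, 16, 3, 3)), ["relu1"])
block.add(Node("add", "add"), ["conv1", "conv2"])  # skip connection from conv1
block.add(Node("head", "linear", (10, 16)), ["add"])

print(block.forward_order())
print(block.param_count())
```

Roughly speaking, a graph hyper-network takes a graph like this as input and produces a tensor of the listed shape for each node that owns parameters; this sketch stops at the representation itself.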

Training a deep-learning model on a dataset is formalized as finding a set of model parameters that minimizes the model's loss function evaluated on the training data. This is typically done by using an iterative optimization algorithm, such as stochastic gradient descent (SGD) or Adam. The drawback to this method is that the minimization can take many hours of computation and a good deal of energy. In practice, researchers will often train many models in order to find the best network architecture and set of hyperparameters, compounding the cost.
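As a minimal illustration of that iterative process, the sketch below runs plain SGD on a hypothetical toy problem: fitting a line y = w*x + b to noiseless points from y = 3x - 1 in pure Python. Every parameter starts from a fixed guess and is nudged by one gradient step per example, over many passes through the data; deep-learning training applies the same kind of update to millions of parameters, which is where the cost comes from.

```python
import random


def mse(w, b, data):
    """Mean squared error of the line y = w*x + b on the dataset."""
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def sgd(data, lr=0.05, epochs=200, seed=0):
    rng = random.Random(seed)
    w, b = 0.0, 0.0                  # fixed starting point (the "initialization")
    data = list(data)                # shuffle a copy, not the caller's list
    history = [mse(w, b, data)]
    for _ in range(epochs):
        rng.shuffle(data)            # stochastic: visit examples in random order
        for x, y in data:
            err = w * x + b - y
            # Gradient of the per-example loss (w*x + b - y)**2 w.r.t. w and b.
            w -= lr * 2 * err * x
            b -= lr * 2 * err
        history.append(mse(w, b, data))  # full-dataset loss after each epoch
    return w, b, history


# Hypothetical noiseless data from y = 3x - 1 on x in [-1, 1].
data = [(i / 10, 3 * (i / 10) - 1) for i in range(-10, 11)]
w, b, history = sgd(data)
print(f"w={w:.3f} b={b:.3f} after {(len(history) - 1) * len(data)} updates")
```

Even this two-parameter model needs thousands of small updates to converge from its starting point. A better initialization shortens that path, which is the gap a parameter-predicting meta-model aims to close.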

To help reduce the cost of training models, the Facebook team created a hyper-model that is trained for a specific dataset. Given a proposed network architecture, the hyper-model predicts performant parameters for that network. Inspired by work on a neural architecture search (NAS) algorithm called Differentiable ARchiTecture Search (DARTS), the team formulated a meta-learning task. This task requires a domain-specific dataset, such as ImageNet, as well as a training set of network architectures expressed as computational graphs. The team then trained a hyper-model using graph-learning techniques; the hyper-model's objective is to predict parameters for the input network architectures that minimize the networks' loss on the domain-specific data.
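The shape of that meta-learning objective can be illustrated with a deliberately tiny, hypothetical example. It is not GHN-2: each "architecture" here is a single scalar scale s (a model built with it computes y_hat = w * s * x), the hyper-model is a linear map over a hand-picked feature encoding rather than a graph network, and gradients are written out by hand. What it shares with the task described above is the objective: the hyper-model's own weights theta are trained so that the parameters it predicts minimize the task loss averaged over a set of training architectures. Once trained, it predicts parameters for an unseen architecture in a single forward pass.

```python
# Hypothetical "domain-specific dataset": y = 2x.
XS = [-1.0, -0.5, 0.5, 1.0]
YS = [2.0 * x for x in XS]


def task_loss(w, s):
    """MSE of a model with architecture s and parameter w: y_hat = w * s * x."""
    return sum((w * s * x - y) ** 2 for x, y in zip(XS, YS)) / len(XS)


def task_grad_w(w, s):
    """d(task_loss)/dw, derived analytically."""
    return sum(2 * (w * s * x - y) * s * x for x, y in zip(XS, YS)) / len(XS)


def features(s):
    # Hypothetical hand-made architecture encoding; a GHN instead learns an
    # embedding of the whole computational graph.
    return [1.0, 1.0 / s]


def predict_w(theta, s):
    """One forward pass of the hyper-model: architecture -> parameter."""
    return sum(t * f for t, f in zip(theta, features(s)))


def meta_train(train_archs, lr=0.1, steps=2000):
    theta = [0.0, 0.0]
    for _ in range(steps):
        grad = [0.0, 0.0]
        for s in train_archs:
            # Chain rule: dL/dtheta_i = dL/dw * dw/dtheta_i = dL/dw * features_i
            g = task_grad_w(predict_w(theta, s), s)
            for i, f in enumerate(features(s)):
                grad[i] += g * f / len(train_archs)
        theta = [t - lr * gi for t, gi in zip(theta, grad)]
    return theta


theta = meta_train([0.5, 1.0, 2.0, 4.0])
w_new = predict_w(theta, 3.0)  # unseen architecture: no gradient steps on it
print(w_new, task_loss(w_new, 3.0))
```

The feature encoding was chosen so that one set of hyper-model weights fits every architecture exactly; with real networks no such guarantee exists, so predicted parameters for unseen architectures can be good or poor depending on the architecture.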

GHN-2 Overview

Source: https://github.com/facebookresearch/ppuda

To assess the performance of their technique, the team trained meta-models for two domain-specific datasets: ImageNet and CIFAR-10. They compared the performance of parameters generated by GHN-2 to those generated by two other baseline meta-models as well as to model parameters produced by standard iterative optimizers. The parameters were predicted for a set of network architectures not used for training the meta-models. GHN-2 "significantly" outperformed the baseline meta-models. Compared to iterative optimizers, the parameters predicted by GHN-2 with only a single forward pass achieved "an accuracy similar to ∼2500 and ∼5000 iterations of SGD on CIFAR-10 and ImageNet respectively."

The GHN-2 model does have some drawbacks. First, a new meta-model must be trained for each domain-specific dataset. Also, although GHN-2 can predict parameters that outperform random initialization, Knyazev notes that "depending on the architecture," the predictions may not be very accurate. In a Reddit discussion about the paper, one user noted:

At the very least, as the author’s tweet thread points out, the predicted parameters are likely a lot better than a random distribution for weight initialization...assuming it generalizes a bit within some class of learning network architectures that is a very interesting and potentially useful development.

The trained GHN-2 model and code as well as the DeepNets-1M dataset are available on GitHub.
