Facebook AI Research(FAIR)とゲルフ大学(University of Guelph)の共同チームは、ディープラーニング・ニューラルネットワークの初期パラメータを予測するGraph HyperNetworks(GHN-2)メタモデルの強化版をオープンソースとして公開した。GHN-2は単一CPU上で1秒未満で動作し、CIFAR-10データセット上において、追加的なトレーニングを必要とせず、最高77パーセントのtop-1精度でコンピュータビジョン(CV)ネットワークの値を予測することができる。
チームの研究者らは、同システムと論文に掲載された一連の実験について、次回のConference on Neural Information Processing Systems(NeurIPS)で発表する予定である、と説明している。ディープラーニングモデルの初期パラメータ予測という問題を解決するために、チームはまず、計算グラフで表現されたニューラルネットワークアーキテクチャのサンプルを100万件収めた、"DeepNets-1M"というデータセットを用意した。次に、このデータセット上で、改良版GHN(graph hyper-network)をメタラーニングを使ってトレーニングすることによって、未知のネットワークアーキテクチャに対するパラメータの予測に使用できるようにしたのだ。作成されたメタモデルは、トレーニングで使用したものよりはるかに大規模なアーキテクチャでも"驚くほど良好に"処理できるものになっており、24MパラメータのResNet-50の初期化に使用した場合には、勾配の更新を行うことなく、CIFAR-10で60パーセントの精度で達成するパラメータを見つけ出している。トレーニングを行ったメタデータとコードに合わせて、DeepNets-1Mトレーニングデータセットと、いくつかのベンチマークテストデータセットもリリースされた。筆頭著者のBoris Knyazev氏によると、
今回の論文によって、私たちは、手書きのオプティマイザから単一メタモデルへのリプレースという目標に対して、さらに一歩近付きました。私たちのメタモデルは、1回の順方向パスを行うのみで、ほぼすべてのニューラルネットワークのパラメータを予測することができます。
データセットを使用したディープラーニングモデルのトレーニングは、トレーニングデータで評価されたモデルの損失関数を最小化するモデルパラメータの検出として形式化される。これには通常、確率的勾配降下法(stochastic gradient descent)(SGD)やAdamといった、反復形式の最適化アルゴリズムが使用される。この方法の欠点は、最小化のために多くの計算時間と相当量のエネルギを必要とすることだ。実際に、最も優れたネットワークアーキテクチャとハイパーパラメータのセットを見つけるために、多数のモデルのトレーニングをすることが多く、コストを悪化させる原因になっている。
モデルトレーニングのコストを低減するため、Facebookのチームは、特定のデータセットでトレーニングされたハイパーモデルを開発した。このハイパーモデルは、提案されたネットワークアーキテクチャに対して、パフォーマンスの優れたパラメータを予測することができる。メタラーニングタスクの開発では、Differentiable ARchiTecture Search(DARTS)と呼ばれるネットワークアーキテクチャ探索(NAS)アルゴリズムからヒントを得た。このタスクにはImageNetのようなドメイン固有データセットと、計算グラフとして表現されたモデルネットワークアーキテクチャのトレーニングセットが必要となる。その上でチームは、グラフラーニングの技術を使ってハイパーモデルのトレーニングを行った。ハイパーモデルの目的は、入力されたネットワークアーキテクチャに対して、そのドメイン固有データセットにおけるネットワークの損失を最小限にするパラメータを予測することだ。
出典: https://github.com/facebookresearch/ppuda
開発した技術を評価するために、チームは、ImageNetとCIFAR-10という2つのドメイン固有データセットを使ってメタモデルをトレーニングした。その上で、GHN-2が生成したパラメータのパフォーマンスを、別の2つのベースラインメタモデルで生成したもの、および標準的な反復形式のオプティマイザで生成したものと比較した。パラメータの予測は、メタモデルのトレーニングでは使用しなかったネットワークアーキテクチャを対象に実施した。結果としてGHN-2は、ベースラインのメタモデルを"大幅に"上回るパフォーマンスを示すと同時に、反復形式のオプティマイザとの比較では、GHN-2が順方向パス1回のみで予測したパラメータが、"SGDがCIFAR-10およびImageNet上でそれぞれ~2,500および~5,000のイテレーションを行ったものと同等の精度"を達成した。
ただし、GHN-2モデルにはいくつかの欠点がある。まず、新たなメタモデルを、ドメイン固有のデータセット毎にトレーニングしなくてはならない。また、GHN-2の予測するパラメータはランダムな選択に比べれば優れているものの、"アーキテクチャによっては"あまり正確ではない可能性もある、とKnyazev氏は述べている。論文に関するRedditの議論では、あるユーザが次のような指摘をしている。
論文の著者もツイートでも指摘しているとおり、少なくとも重みの初期化をランダム分布するよりはずっと優れているでしょう ... 非常に興味深く、潜在的に有用な開発であるラーニングネットワークアーキテクチャの一部のクラスでは、ある程度の普及を見るかも知れません。
トレーニング済のGHN-2モデルとコードは、DeepNets-1MデータセットとともにGitHubで公開されている。