BT

最新技術を追い求めるデベロッパのための情報コミュニティ

寄稿

Topics

地域を選ぶ

InfoQ ホームページ アーティクル Deep Java Library(DJL)の紹介

Deep Java Library(DJL)の紹介

ブックマーク

原文(投稿日:2019/12/19)へのリンク

AmazonのDJLは、マシンラーニング(ML)とディープラーニング(DL)モデルをJavaでネイティブ開発するためのディープラーニングツールキットです。ディープラーニングフレームワークを簡単に使用できるようにしてくれます。re:Invent 2019に合わせてオープンソース公開されたDJLは、推論のトレーニングやテスト、実行のための高レベルなAPIセットを提供します。Javaの開発者が自分自身でモデルを開発したり、データサイエンティストがPythonで開発したトレーニング済モデルをJavaコードで使用することが可能になります。

さらにDJLは、エンジンとディープラーニングフレームワークに依存しないことによって、"write once, run anywhere(WORA)"というJavaのモットーを実現しています。一度作成したコードは、任意のエンジン上で実行可能です。現時点では、ディープニューラルネットワークを簡易化するMLエンジンである、Apache MXNet上での実装が提供されています。DJLのAPIはJNA(Java Native Access)を使って、対応するApache MXNetオペレーションを呼び出します。さらにDJLは、ハードウェアコンフィギュレーションに基づいたCPU/GPU自動検出機能を提供するインフラストラクチャ管理を行うことで、良好なパフォーマンスを実現しています。

APIでは、モデル開発で一般的に使用される機能を抽象化することで、Java開発者の既存ナレッジをMLへ簡単に転換して活用できるようにします。DJLの動作を実際に確認するため、簡単な例として靴(履物)の分類モデルの開発で使用してみましょう。

マシンラーニングのライフサイクル

靴の分類モデルの作成は、マシンラーニングのライフサイクルに沿って行われます。MLのライフサイクルは一般的なソフトウェア開発ライフサイクルとは違って、6つの独立したステップから構成されます。

  1. データの取得
  2. データのクリーニングと準備作業
  3. モデルの生成
  4. モデルの評価
  5. モデルの展開
  6. モデルからの予測(あるいは推論)の取得

ライフサイクルの実行結果として得られるのは、コンサルトからの答(ないし推論)を返すことができるマシンラーニングモデルです。


モデルとは、データの中から見いだされた傾向やパターンの数学的モデルに他なりません。優れたデータは、あらゆるMLプロジェクトの礎になります。

ステップ1では、データは信頼できるソースから取得されます。ステップ2では、データはクリーニングされ、変換されて、マシンが学習可能なフォーマットに落とし込まれます。クリーニングと変換というこの2つのプロセスは、マシンラーニングライフサイクルにおいて、最も時間を要する部分であることが少なくありません。DJLでは、トランスレータを使用したイメージのプリプロセスを可能にすることで、このプロセスを簡略化しています。トランスレータは、期待されるパラメータに基づいたイメージサイズの変更や、カラーからグレースケールへの変換といった作業を行うことができます。

マシンラーニングを新たに始めようとする開発者は、このクリーニングや変換に要する時間を軽視していることが多いので、トランスレータは、プロセスの立ち上げをスムーズにする上で大きな役割を果たしてくれます。トレーニングプロセスのステップ3では、マシンラーニングアルゴリズムがデータ上に複数のパス(あるいはエポック)を生成して、それらを学習し、さまざまなタイプの履物について学ぼうとします。履物に関して発見したトレンドやパターンはモデル内に格納されます。ステップ4は、履物を正しく識別する能力を判断するためにモデルを評価する過程において、トレーニングの一部として実行されます。誤りが見つかればここで訂正されます。ステップ5では、モデルが運用環境にデプロイされます。モデルの運用が開始されれば、ステップ6で他システムからのモデルの利用が可能になります。

モデルは通常、コードに動的にロードするか、あるいはRESTベースのHTTPSエンドポイント経由でアクセスします。

データ

履物分類モデルは多クラス分類(multiclass classification)型のコンピュータビジョン(CV)モデルであり、教師あり学習(supervised learning)を使ってトレーニングされ、履物をブーツ、サンダル、シューズ、スリッパの4つに分類します。教師あり学習では、マシンラーニングを行う手段として、推論しようとするターゲット(あるいは答)によって事前にラベル付けされたデータが必要になります。

履物分類モデルでデータソースとして使用するUTZappos50Kデータセットは、オースチンにあるテキサス大学の提供によるもので、学術上の非営利目的で無償使用することが可能です。データセットは、Zappos.comから収集した、50,025のラベル付きカタログイメージで構成されています。

履物データはローカルに保存され、ローカルフォルダからイメージを取得する、DJLのImageFolderデータセットを使用してロードされます。

// identify the location of the training data
String trainingDatasetRoot = "src/test/resources/imagefolder/train";

// identify the location of the validation data
String validateDatasetRoot = "src/test/resources/imagefolder/validate";

//create training data ImageFolder dataset
ImageFolder trainingDataset = initDataset(trainingDatasetRoot);

//create validation data ImageFolder dataset
ImageFolder validateDataset = initDataset(validateDatasetRoot);

今回のデータのローカル構造では、ブーツの識別ラベル(アンクル、ニーハイ、ミッドカーフ、オーバーニー)のような、UTZappos50Kの最も詳細な識別レベルまでは使用せず、最高レベルであるブーツ、サンダル,シューズ、スリッパの識別に留めました。

DJL用語でのデータセットは、単にトレーニングデータを保持しています。データのダウンロード、データの抽出、データのトレーニングセットと評価セットへの自動分離に使用することのできるデータセット実装が用意されています。

自動分離が有用なのは、モデルをトレーニングした同じデータを、モデルパフォーマンスの評価で使わないことが重要であるためです。トレーニングデータセットは、履物データからトレンドとパターンを見つけ出すために、モデルが使用します。評価データセットは、履物の区別に関するモデルの正確性をバイアスなく見積もることによって、モデルのパフォーマンスを評価するために使用されます。

トレーニングに使用したものと同じデータを使って評価が行われた場合には、テストに使用したデータをモデルがすでに知っているのですから、モデルがシューズを分類する能力に対する信頼性ははるかに低いものになってしまいます。現実の世界においても、教師が授業中に示したものとまったく同じ問題をテストに使用するようなことはないはずです。そのような方法では、生徒の本当の知識や理解度を測ることはできません。マシンラーニングモデルでも、同じことが当てはまるのです。

トレーニング

トレーニング用と評価用のデータセットが用意できたので、さっそくモデルのトレーニング(あるいは生成)するニューラルネットワークを使用しましょう。

public final class Training extends AbstractTraining {

     . . .

     @Override
     protected void train(Arguments arguments) throws IOException {

          // identify the location of the training data
          String trainingDatasetRoot = "src/test/resources/imagefolder/train";

          // identify the location of the validation data
          String validateDatasetRoot = "src/test/resources/imagefolder/validate";

          //create training data ImageFolder dataset
          ImageFolder trainingDataset = initDataset(trainingDatasetRoot);

          //create validation data ImageFolder dataset
          ImageFolder validateDataset = initDataset(validateDatasetRoot);

          . . .

          try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
             TrainingConfig config = setupTrainingConfig(loss);

             try (Trainer trainer = model.newTrainer(config)) {
                 trainer.setMetrics(metrics);

                 trainer.setTrainingListener(this);

                 Shape inputShape = new Shape(1, 3, NEW_HEIGHT, NEW_WIDTH);

                 // initialize trainer with proper input shape
                 trainer.initialize(inputShape);

                 //find the patterns in data
                 fit(trainer, trainingDataset, validateDataset, "build/logs/training");

                 //set model properties
                 model.setProperty("Epoch", String.valueOf(EPOCHS));
                 model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));

                // save the model after done training for inference later
                //model saved as shoeclassifier-0000.params
                model.save(Paths.get(modelParamsPath), modelParamsName);
             }
          }
     }

 }

最初のステップでは、Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)をコールしてモデルインスタンスを生成します。マシンラーニングの一形式であるディープラーニングでは、モデルのトレーニングにニューラルネットワークを使用します。ニューラルネットワークは、人の脳にあるニューロンをモデル化したものです。ニューロンは、他のセルに情報(あるいはデータ)を転送するセルの単位です。

イメージ分類用のニューラルネットワークとしては、ResNet-50が多く使用されています。名称の50は、オリジナル入力から最終推論までの間に50の学習(あるいはニューロン)層のあることを意味するものです。getModel()メソッドは空のモデルを生成し、ResNet-50ニューラルネットワークを構築した上で、そのニューラルネットワークをモデルにセットします。

public class Models {
   public static ai.djl.Model getModel(int numOfOutput, int height, int width) {
       //create new instance of an empty model
       ai.djl.Model model = ai.djl.Model.newInstance();

       //Block is a composable unit that forms a neural network; combine them
       //like Lego blocks to form a complex network
       Block resNet50 =
               //construct the network
               new ResNetV1.Builder()
                       .setImageShape(new Shape(3, height, width))
                       .setNumLayers(50)
                       .setOutSize(numOfOutput)
                       .build();

       //set the neural network to the model
       model.setBlock(resNet50);
       return model;
   }
}

次のステップではmodel.newTrainer(config)メソッドを呼び出して、Trainerのセットアップとコンフィギュレーションを行います。configオブジェクトの初期化はsetupTrainingConfig(loss)メソッドをコールして行います。このメソッドは、ネットワークのトレーニング方法を決定するためのトレーニング設定(またはハイパーパラメータ)を設定します。

次のステップは、以下の情報をセットすることで、Trainerに次のような機能追加を可能にします。

  • trainer.setMetrics(metrics)を使用したMetrics
  • trainer.setTrainingListener(this)を使用したリスナのトレーニング
  • trainer.initialize(inputShape)を使用した適切なインプット

Metricsは、トレーニング中の重要なパフォーマンス指標(KPIs, Key Performance Indicators)の収集とレポートを行って、トレーニングのパフォーマンスや安定性の分析や監視に使用できるようにします。次のステップでは、fit(trainer, trainingDataset, valifdateDataset, "build/log/training")メソッドをコールしてトレーニングプロセスを起動し、トレーニングデータ全体を繰り返した後、検出したパターンをモデルに格納します。トレーニングの最後には、 model.save(Paths.get(modelParamsPath), modelParamsName)</ t3>メソッドを使用して、パフォーマンスの高い検証済モデルアーティファクトがプロパティとともにローカルに保存されます。

トレーニングプロセス中にレポートされたメトリクスは次のようになります。各エポック(あるいはパス)毎にモデルの正確度が向上し、エポック9の最終的なトレーニング精度が90パーセントに達している点に注目してください。

推論

モデルの生成が完了すれば、分類(あるいはターゲット)の分からない、新たなデータを対象にした推論(あるいは予測)の実行に使用できるようになります。


private Classifications predict() throws IOException, ModelException, TranslateException  {
   //the location to the model saved during training
   String modelParamsPath = "build/logs";

   //the name of the model set during training
   String modelParamsName = "shoeclassifier";

   // the path of image to classify
   String imageFilePath = "src/test/resources/slippers.jpg";

   //Load the image file from the path
   BufferedImage img = BufferedImageUtils.fromFile(Paths.get(imageFilePath));

   //holds the probability score per label
   Classifications predictResult;

   try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
       //load the model
       model.load(Paths.get(modelParamsPath), modelParamsName);

       //define a translator for pre and post processing
       Translator<BufferedImage, Classifications> translator = new MyTranslator();

       //run the inference using a Predictor
       try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator)) {
           predictResult = predictor.predict(img);
       }
   }

   return predictResult;
}

モデルと分類するイメージに対する必要なパスを設定した後、Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)メソッドを使用して空のモデルインスタンスを用意して、model.load(Paths.get(modelParamsPath), modelParamsName)メソッドでその初期化を行います。これにより、これまでのステップでトレーニングされたモデルがロードされます。

次に、model.newPrediction(translator)メソッドを使ってTranslatorを指定して、Predictorを初期化します。DJL用語でのTranslatorは、モデルのプリプロセスとポストプロセスの機能を提供します。例えばCVモデルでは、イメージをグレースケールで再形成する必要がありますが、Translatorはこれを行うことができます。PredictorはロードしたModel上で、poredictor.predict(img)メソッドを使って分類するイメージを渡すことで、推論の実行を可能にします。

この例では単一推論ですが、DJLはバッチ推論もサポートしています。推論結果はpredictResultに格納されます。この中には、ラベル毎の確率評価(probability estimate)が含まれています。

(イメージ毎の)推論はそれぞれの確率評価とともに、次のように表示されます。

イメージ 確率スコア(Probability Score)

[INFO ] — [
                class: "0", probability: 0.98985
                class: "1", probability: 0.00225
                class: "2", probability: 0.00224
                class: "3", probability: 0.00564
            ]
Class 0は、98.98パーセントの確率でブーツであると示されています。

[INFO ] — [
               class: "0", probability: 0.02111
               class: "1", probability: 0.76524
               class: "2", probability: 0.01159
               class: "3", probability: 0.20204
          ]
Class 1は、76.52パーセントの確率でサンダルであると示されています。

[INFO ] — [
               class: "0", probability: 0.05523
               class: "1", probability: 0.01417
               class: "2", probability: 0.87900
               class: "3", probability: 0.05158
              ]

Class 2は、87.90パーセントの確率でシューズであると示されています。

[INFO ] — [
                class: "0", probability: 0.00003
                class: "1", probability: 0.01133
                class: "2", probability: 0.00179
                class: "3", probability: 0.98682
              ]
Class 3は、98.68パーセントの確率でスリッパであると示されています。

DJLは、他のJavaライブラリと同じようなネイティブJava開発のエクスペリエンスと機能を提供します。APIは、ディープラーニング作業を完遂させるためのベストプラクティスに開発者を導くように設計されています。DJLを始めるには、MLのライフサイクルを十分理解する必要があります。MLが初めてならば、まず概要を読むか、あるいはInfoQのアーティクルシリーズ"an introduction to machine learning for software developers"から始めるとよいでしょう。ライフサイクルと一般的なML用語が理解できれば、DIJのAPIはすぐに理解できるようになります。

AmazonはDJLをオープンソースとして公開しています。ツールキットに関する詳細な情報はDJLのWebサイトや、Java Library API Specificationのページで確認することができます。サンプルをより深く調べるために、履物分類モデルのコードをレビューすることも可能です。

著者について

Kesha Williams氏は、受賞経験のあるソフトウェアエンジニアであり、マシンラーニングの実践家で、24年のキャリアを持つA Cloud Guruのテクニカルインストラクタです。氏は米国や欧州やアジアで数千人のJavaソフトウェアエンジニアのトレーニングや指導を行うと同時に、大学レベルでの教育を行ってきました。革新的な技術を実証するために氏は、イノベーションチームを定期的に指導し、世界中のカンファレンスで学んだことを公開しています・。TEDステージでマシンラーニングについて講演し、TEDのSpotlight Presentation Academyを受賞した経験もあります。さらに、人工知能分野における先駆的な仕事によって、Alexa ChampionとAmazonのAWS Machine Learning Heroの両方から名誉ある賞を受けています。さらに余暇には、オンラインのソーシャルおよびプロフェッショナルネットワーキングプラットフォームであるColors of STEMを使用して、ハイテク界の女性たちを指導しています。

この記事に星をつける

おすすめ度
スタイル

こんにちは

コメントするには InfoQアカウントの登録 または が必要です。InfoQ に登録するとさまざまなことができます。

アカウント登録をしてInfoQをお楽しみください。

HTML: a,b,br,blockquote,i,li,pre,u,ul,p

コミュニティコメント

HTML: a,b,br,blockquote,i,li,pre,u,ul,p

HTML: a,b,br,blockquote,i,li,pre,u,ul,p

BT