
Getting to Know Deep Java Library (DJL)

Key Takeaways

  • Developers can build, train, and deploy machine learning (ML) and deep learning (DL) models using Java and their favorite IDE
  • DJL simplifies the use of deep learning (DL) frameworks and currently supports Apache MXNet
  • The open-source nature of DJL should be mutually beneficial for the toolkit and its users
  • DJL is engine-agnostic, which means developers can write code once and run it on any engine
  • Java developers should have an understanding of the ML lifecycle and common ML terms before attempting to use DJL

Amazon’s DJL is a deep learning toolkit used to develop machine learning (ML) and deep learning (DL) models natively in Java while simplifying the use of deep learning frameworks. A toolkit open-sourced just in time for re:Invent 2019, DJL provides a set of high-level APIs to train, test and run inference. Java developers can develop their own models or utilize pre-trained models developed by data scientists in Python from their Java code.

DJL stays true to Java’s motto, "write once, run anywhere" (WORA), by being engine-agnostic and deep learning framework-agnostic. Developers can write code once that runs on any engine. DJL currently provides an implementation for Apache MXNet, an ML engine that eases the development of deep neural networks. DJL APIs use JNA (Java Native Access) to call the corresponding Apache MXNet operations. DJL also handles infrastructure management, providing automatic CPU/GPU detection based on the hardware configuration to ensure good performance.

DJL APIs abstract commonly used functions for developing models, enabling Java developers to leverage existing knowledge to ease the transition to ML. To see DJL in action, let’s walk through a simple example: developing a footwear classification model.

The Machine Learning Lifecycle

The machine learning lifecycle is followed to produce the footwear classification model. The ML lifecycle is different from the traditional software development lifecycle and consists of six concrete steps:

  1. Obtain the data
  2. Clean and prepare the data
  3. Generate the model
  4. Evaluate the model
  5. Deploy the model
  6. Obtain a prediction (or inference) from the model

The end result of the lifecycle is a machine learning model that can be consulted and return an answer (or prediction).


A model is simply a mathematical representation of trends and patterns found in data. Good data are the foundation for all ML projects.

In step 1, data are obtained from a reputable source. In step 2, the data are cleaned, transformed, and put in a format that a machine can learn from. The cleaning and transformation process is often the most time-intensive part of the machine learning lifecycle. DJL eases this process for developers by providing the ability to preprocess images using translators. Translators can perform tasks such as resizing images to the expected dimensions or converting images from color to grayscale.
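DJL's translators perform this work inside the library, and the article does not show their internals. As a plain-Java illustration of the two preprocessing steps just mentioned, the sketch below resizes an image with nearest-neighbor sampling and converts it to grayscale using the common ITU-R BT.601 luma weights. The ImagePreprocessing class and its methods are hypothetical names for this sketch, not DJL APIs.

```java
import java.awt.image.BufferedImage;

/**
 * Hypothetical helper (not a DJL API) illustrating two preprocessing steps
 * that a translator might perform before an image reaches a model.
 */
final class ImagePreprocessing {

    private ImagePreprocessing() {
    }

    /** Resizes an image to width x height using nearest-neighbor sampling. */
    static BufferedImage resizeNearest(BufferedImage src, int width, int height) {
        if (width <= 0 || height <= 0) {
            throw new IllegalArgumentException("width and height must be positive");
        }
        BufferedImage dst = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < height; y++) {
            // map each destination row back to the nearest source row
            int srcY = y * src.getHeight() / height;
            for (int x = 0; x < width; x++) {
                // map each destination column back to the nearest source column
                int srcX = x * src.getWidth() / width;
                dst.setRGB(x, y, src.getRGB(srcX, srcY));
            }
        }
        return dst;
    }

    /** Converts a color image to 8-bit grayscale using the BT.601 luma weights. */
    static BufferedImage toGrayscale(BufferedImage src) {
        BufferedImage dst = new BufferedImage(
                src.getWidth(), src.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
        for (int y = 0; y < src.getHeight(); y++) {
            for (int x = 0; x < src.getWidth(); x++) {
                int rgb = src.getRGB(x, y);
                int red = (rgb >> 16) & 0xFF;
                int green = (rgb >> 8) & 0xFF;
                int blue = rgb & 0xFF;
                int luma = (int) Math.round(0.299 * red + 0.587 * green + 0.114 * blue);
                // write the gray level directly into the single-band raster
                dst.getRaster().setSample(x, y, 0, luma);
            }
        }
        return dst;
    }
}
```

A production pipeline would usually prefer smoother interpolation than nearest-neighbor; it is used here only because it keeps the sketch short and dependency-free.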

Developers transitioning to machine learning often underestimate the time needed to clean and transform data, so translators are a great way to jumpstart the process. During the training process, step 3, a machine learning algorithm makes multiple passes (or epochs) over the data, studying them and trying to learn the different types of footwear. The trends and patterns found, as they relate to footwear, are stored in the model. Step 4 occurs as part of training, when the model is evaluated to determine how well it identifies footwear; if mistakes are uncovered, they are corrected. In step 5, the model is deployed to a production environment. Once the model is in production, step 6 allows the model to be consumed by other systems.
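To make "multiple passes (or epochs)" concrete, here is a toy sketch, unrelated to DJL's own training code, that fits a single weight w in y = w * x with plain gradient descent on the mean squared error. Each epoch visits every training example once and nudges w to reduce the error. EpochDemo and its methods are hypothetical names.

```java
/**
 * Toy illustration (not DJL) of training over multiple epochs:
 * fits y = w * x with plain gradient descent on the mean squared error.
 */
final class EpochDemo {

    private EpochDemo() {
    }

    /** Mean squared error of the model y = w * x on the given data. */
    static double meanSquaredError(double w, double[] xs, double[] ys) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double err = w * xs[i] - ys[i];
            sum += err * err;
        }
        return sum / xs.length;
    }

    /** Runs the given number of epochs and returns the learned weight. */
    static double train(double[] xs, double[] ys, int epochs, double learningRate) {
        double w = 0.0; // start from an uninformed guess
        for (int epoch = 0; epoch < epochs; epoch++) {
            // one epoch = one full pass over every training example
            double gradient = 0.0;
            for (int i = 0; i < xs.length; i++) {
                gradient += 2.0 * (w * xs[i] - ys[i]) * xs[i];
            }
            // step against the average gradient to reduce the error
            w -= learningRate * gradient / xs.length;
        }
        return w;
    }
}
```

With xs = {1, 2, 3, 4}, ys = {2, 4, 6, 8}, and a learning rate of 0.01, w moves toward 2 as the number of epochs grows. A deep network like ResNet-50 performs the same kind of iterative adjustment, but across millions of weights.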

Typically, models can be dynamically loaded in your code or accessed via a REST-based HTTPS endpoint.

Data

The footwear classification model is a multiclass computer vision (CV) classification model, trained using supervised learning, that classifies footwear into one of four class labels: boots, sandals, shoes, or slippers. Supervised learning requires data that are already labeled with the target (or answer) you’re trying to predict; this is how the machine learns.

The data source for the footwear classification model is the UTZappos50k dataset provided by The University of Texas at Austin, which is freely available for academic, non-commercial use. The dataset consists of 50,025 labeled catalog images collected from Zappos.com.

The footwear data were saved locally and loaded using DJL’s ImageFolder dataset, which retrieves images from a local folder.

// identify the location of the training data
String trainingDatasetRoot = "src/test/resources/imagefolder/train";

// identify the location of the validation data
String validateDatasetRoot = "src/test/resources/imagefolder/validate";

// create the training ImageFolder dataset
// (initDataset is a helper method in the example project; it is not shown here)
ImageFolder trainingDataset = initDataset(trainingDatasetRoot);

// create the validation ImageFolder dataset
ImageFolder validateDataset = initDataset(validateDatasetRoot);

When structuring the data locally, I didn’t go down to the most granular level of labels in the UTZappos50k dataset, such as the ankle, knee-high, mid-calf, and over-the-knee labels for boots. My local data are kept at the highest level of classification, which includes only boots, sandals, shoes, and slippers.
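Image-folder datasets conventionally expect one subfolder per class label (for example, train/boots and train/sandals). The following sketch, which uses only java.nio.file and is not DJL's implementation, shows how labels can be derived from such a layout. FolderLabels is a hypothetical helper, and sorting the folder names alphabetically is an assumption of this sketch; that order (boots, sandals, shoes, slippers) does line up with the class indices 0 to 3 in the inference results later in this article.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Hypothetical helper (not part of DJL) for the folder-per-label convention:
 * root/boots/*.jpg, root/sandals/*.jpg, and so on.
 */
final class FolderLabels {

    private FolderLabels() {
    }

    /** Returns the subfolder names under root, sorted alphabetically, as class labels. */
    static List<String> classLabels(Path root) throws IOException {
        try (Stream<Path> children = Files.list(root)) {
            return children
                    .filter(Files::isDirectory) // ignore stray files in the root folder
                    .map(p -> p.getFileName().toString())
                    .sorted()
                    .collect(Collectors.toList());
        }
    }

    /** Returns the label of an image: the name of the folder that contains it. */
    static String labelOf(Path imageFile) {
        return imageFile.getParent().getFileName().toString();
    }
}
```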

In DJL terms, a dataset simply holds the training data. There are dataset implementations that can be used to download data (based on the URL you provide), extract data, and automatically separate data into training and validation sets.

The automatic separation is a useful feature as it is important to never use the same data the model was trained with to validate the model’s performance. The training dataset is used by the model to find trends and patterns in footwear data. The validation dataset is used to qualify the model’s performance by providing an unbiased estimate of the model’s accuracy at classifying footwear.

If the model were validated using the same data it was trained with, we could place much less confidence in its ability to classify shoes, because the model would be tested with data it has already seen. In the real world, a teacher would not test a student using the exact same questions provided on a study guide, because that would not measure the student’s true knowledge or understanding of the material; the same principle holds true for machine learning models.
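Some DJL datasets can perform this split automatically. As a hypothetical, library-free illustration of what such a split involves, the sketch below shuffles items with a fixed seed and divides them into training and validation lists so that each input element lands in exactly one of the two. The TrainValidationSplit class, the 80/20 ratio, and the seed used in the usage note are assumptions, not values from the article.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical helper (not a DJL API): shuffles items with a fixed seed and
 * splits them into training and validation lists.
 */
final class TrainValidationSplit<T> {
    final List<T> training;
    final List<T> validation;

    private TrainValidationSplit(List<T> training, List<T> validation) {
        this.training = training;
        this.validation = validation;
    }

    static <T> TrainValidationSplit<T> split(List<T> items, double trainFraction, long seed) {
        if (!(trainFraction > 0.0 && trainFraction < 1.0)) {
            throw new IllegalArgumentException("trainFraction must be between 0 and 1");
        }
        // copy before shuffling so the caller's list is left untouched
        List<T> shuffled = new ArrayList<>(items);
        Collections.shuffle(shuffled, new Random(seed));
        int cut = (int) Math.round(shuffled.size() * trainFraction);
        // everything before the cut trains the model; everything after validates it
        return new TrainValidationSplit<>(
                new ArrayList<>(shuffled.subList(0, cut)),
                new ArrayList<>(shuffled.subList(cut, shuffled.size())));
    }
}
```

For example, TrainValidationSplit.split(imagePaths, 0.8, 42L), where imagePaths is a hypothetical list of image files, would put about 80% of the paths in the training list and the rest in the validation list; fixing the seed keeps the split reproducible across runs.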

Training

Now that we have the footwear data separated into training and validation sets, let’s use a neural network to train (or produce) the model.

public final class Training extends AbstractTraining {

    . . .

    @Override
    protected void train(Arguments arguments) throws IOException {

        // identify the location of the training data
        String trainingDatasetRoot = "src/test/resources/imagefolder/train";

        // identify the location of the validation data
        String validateDatasetRoot = "src/test/resources/imagefolder/validate";

        // create training data ImageFolder dataset
        ImageFolder trainingDataset = initDataset(trainingDatasetRoot);

        // create validation data ImageFolder dataset
        ImageFolder validateDataset = initDataset(validateDatasetRoot);

        . . .

        try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
            TrainingConfig config = setupTrainingConfig(loss);

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(metrics);

                trainer.setTrainingListener(this);

                // input shape: batch size 1, 3 color channels, height, width
                Shape inputShape = new Shape(1, 3, NEW_HEIGHT, NEW_WIDTH);

                // initialize trainer with proper input shape
                trainer.initialize(inputShape);

                // find the patterns in data
                fit(trainer, trainingDataset, validateDataset, "build/logs/training");

                // set model properties
                model.setProperty("Epoch", String.valueOf(EPOCHS));
                model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));

                // save the model after done training for inference later
                // model saved as shoeclassifier-0000.params
                model.save(Paths.get(modelParamsPath), modelParamsName);
            }
        }
    }

}

The first step is to get a model instance by calling Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH). Deep learning, a form of machine learning, uses a neural network in order to train the model. A neural network is modeled after neurons in the human brain; neurons are cells that transmit information (or data) to other cells.
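An artificial neuron is a loose mathematical analogy to that idea: it combines its inputs into a weighted sum, adds a bias, and passes the result through an activation function. The sketch below is a toy, hypothetical Neuron class that uses ReLU as the activation; it illustrates the concept only and is not how DJL or ResNet-50 are implemented (ResNet-50 is built mostly from convolutional layers).

```java
/**
 * Toy artificial neuron (hypothetical, for illustration only): a weighted sum
 * of inputs plus a bias, passed through a ReLU activation.
 */
final class Neuron {
    private final double[] weights;
    private final double bias;

    Neuron(double[] weights, double bias) {
        this.weights = weights.clone(); // defensive copy so callers cannot change the weights
        this.bias = bias;
    }

    /** Computes max(0, w . x + b): the signal this neuron passes on. */
    double fire(double[] inputs) {
        if (inputs.length != weights.length) {
            throw new IllegalArgumentException("expected " + weights.length + " inputs");
        }
        double sum = bias;
        for (int i = 0; i < inputs.length; i++) {
            sum += weights[i] * inputs[i];
        }
        // ReLU: pass positive signals through, silence negative ones
        return Math.max(0.0, sum);
    }
}
```

Training adjusts values like these weights and biases; a network stacks many layers of such units so that later layers can combine the patterns detected by earlier ones.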

ResNet-50 is a neural network architecture often used for image classification; the 50 indicates that the network is 50 layers deep, with layers of neurons stacked between the original input data and the final prediction. The getModel() method creates an empty model, constructs a ResNet-50 neural network, and sets that network as the model’s block.

public class Models {
   public static ai.djl.Model getModel(int numOfOutput, int height, int width) {
       //create new instance of an empty model
       ai.djl.Model model = ai.djl.Model.newInstance();

       //Block is a composable unit that forms a neural network; combine them
       //like Lego blocks to form a complex network
       Block resNet50 =
               //construct the network
               new ResNetV1.Builder()
                       .setImageShape(new Shape(3, height, width))
                       .setNumLayers(50)
                       .setOutSize(numOfOutput)
                       .build();

       //set the neural network to the model
       model.setBlock(resNet50);
       return model;
   }
}

The next step is to set up and configure a Trainer by calling the model.newTrainer(config) method. The config object was initialized by calling the setupTrainingConfig(loss) method, which sets the training configuration (the loss function and hyperparameters) that determines how the network is trained.
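The article does not show which loss function is passed to setupTrainingConfig(loss). For a multiclass classifier like this one, softmax cross-entropy is a common choice, so the sketch below assumes it: it computes the loss for a single example from the network's raw scores, subtracting the largest score first for numerical stability. CrossEntropy is a hypothetical name, not a DJL class.

```java
/**
 * Sketch of softmax cross-entropy for one example (an assumed loss; the
 * article does not say which loss the footwear model uses).
 */
final class CrossEntropy {

    private CrossEntropy() {
    }

    /** Returns -log(softmax(scores)[trueClass]), computed in a numerically stable way. */
    static double loss(double[] scores, int trueClass) {
        if (trueClass < 0 || trueClass >= scores.length) {
            throw new IllegalArgumentException("trueClass out of range");
        }
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double sumExp = 0.0;
        for (double s : scores) {
            sumExp += Math.exp(s - max); // shifted so exp cannot overflow
        }
        // log softmax_k = (s_k - max) - log(sum_j exp(s_j - max))
        return -((scores[trueClass] - max) - Math.log(sumExp));
    }
}
```

Equal scores for all four classes give a loss of ln 4 (about 1.386); the loss shrinks toward 0 as the score for the true class dominates, which is what training tries to achieve.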

The next steps add features to the Trainer by setting:

  • metrics, using trainer.setMetrics(metrics)
  • a training listener, using trainer.setTrainingListener(this)
  • the proper input shape, using trainer.initialize(inputShape)

Metrics collect and report key performance indicators (KPIs) during training that can be used to analyze and monitor training performance and stability. The next step is to kick off the training process by calling the fit(trainer, trainingDataset, validateDataset, "build/logs/training") method, which iterates over the training data and stores the patterns found in the model. At the end of training, a well-performing validated model artifact is saved locally along with its properties using the model.save(Paths.get(modelParamsPath), modelParamsName) method.

The metrics reported during the training process are shown below. Notice that with each epoch (or pass) the accuracy of the model improves; the final training accuracy for epoch 9 is 90%.
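The accuracy figure above is the fraction of examples classified correctly. As an illustration of how such a metric can be accumulated batch by batch during an epoch and reset between epochs, here is a minimal sketch; RunningAccuracy is a hypothetical class, not DJL's metric implementation.

```java
/**
 * Hypothetical sketch (not DJL's metric classes) of an accuracy metric that is
 * updated batch by batch during an epoch and reset between epochs.
 */
final class RunningAccuracy {
    private long correct;
    private long total;

    /** Adds one batch of predicted and true class indices to the running totals. */
    void update(int[] predicted, int[] actual) {
        if (predicted.length != actual.length) {
            throw new IllegalArgumentException("predicted and actual must have the same length");
        }
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] == actual[i]) {
                correct++;
            }
        }
        total += predicted.length;
    }

    /** Fraction of examples classified correctly since the last reset. */
    double value() {
        if (total == 0) {
            throw new IllegalStateException("no examples recorded yet");
        }
        return (double) correct / total;
    }

    /** Clears the totals, for example at the start of a new epoch. */
    void reset() {
        correct = 0;
        total = 0;
    }
}
```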

Inference

Now that we’ve generated the model, it can be used to perform inference (or prediction) on new data for which we do not know the classification (or target).

private Classifications predict() throws IOException, ModelException, TranslateException  {
   //the location to the model saved during training
   String modelParamsPath = "build/logs";

   //the name of the model set during training
   String modelParamsName = "shoeclassifier";

   // the path of image to classify
   String imageFilePath = "src/test/resources/slippers.jpg";

   //Load the image file from the path
   BufferedImage img = BufferedImageUtils.fromFile(Paths.get(imageFilePath));

   //holds the probability score per label
   Classifications predictResult;

   try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
       //load the model
       model.load(Paths.get(modelParamsPath), modelParamsName);

       // define a translator for pre- and post-processing
       // (MyTranslator is a custom Translator from the example project; it is not shown here)
       Translator<BufferedImage, Classifications> translator = new MyTranslator();

       //run the inference using a Predictor
       try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator)) {
           predictResult = predictor.predict(img);
       }
   }

   return predictResult;
}

After setting the paths to the saved model and to the image to classify, obtain a model instance with the same ResNet-50 architecture using the Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) method, and then load the trained parameters into it using the model.load(Paths.get(modelParamsPath), modelParamsName) method. This restores the model that was trained in the previous step.

Next, create a Predictor with a specified Translator using the model.newPredictor(translator) method. In DJL terms, a Translator provides model pre-processing and post-processing functionality. For example, with CV models, input images usually need to be resized to the dimensions the model expects and converted into a numeric format it can consume; a Translator can do this. The Predictor then performs inference on the loaded Model through the predictor.predict(img) method, which takes the image to classify.
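A Translator's post-processing step commonly turns the network's raw output scores into probabilities that sum to 1, typically with softmax. The article does not show what MyTranslator does internally, so the sketch below is only an assumption-labeled, plain-Java illustration of that conversion; Softmax is a hypothetical name.

```java
/**
 * Hypothetical sketch of a classification post-processing step: converts raw
 * network scores (logits) into probabilities with softmax.
 */
final class Softmax {

    private Softmax() {
    }

    static double[] probabilities(double[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("scores must not be empty");
        }
        // subtract the largest score so Math.exp cannot overflow
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        // normalize so the probabilities add up to 1
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }
}
```

Softmax preserves the ranking of the scores, so the class with the highest raw score is also the class with the highest probability.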

This example shows a single prediction, but DJL also supports batch predictions. The result is stored in predictResult, which contains the probability estimate for each label.

The inferences (per image) are shown below with their corresponding probability scores.

[INFO ] - [
    class: "0", probability: 0.98985
    class: "1", probability: 0.00225
    class: "2", probability: 0.00224
    class: "3", probability: 0.00564
]
Class 0 represents boots with a probability score of 98.98%.

[INFO ] - [
    class: "0", probability: 0.02111
    class: "1", probability: 0.76524
    class: "2", probability: 0.01159
    class: "3", probability: 0.20204
]
Class 1 represents sandals with a probability score of 76.52%.

[INFO ] - [
    class: "0", probability: 0.05523
    class: "1", probability: 0.01417
    class: "2", probability: 0.87900
    class: "3", probability: 0.05158
]
Class 2 represents shoes with a probability score of 87.90%.

[INFO ] - [
    class: "0", probability: 0.00003
    class: "1", probability: 0.01133
    class: "2", probability: 0.00179
    class: "3", probability: 0.98682
]
Class 3 represents slippers with a probability score of 98.68%.
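Turning a probability vector like those above into a readable answer amounts to picking the highest-scoring index and looking up its label. The sketch below does this in plain Java; TopLabel is a hypothetical helper, and its label order is taken from the mapping the results above describe (0 = boots, 1 = sandals, 2 = shoes, 3 = slippers).

```java
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/**
 * Hypothetical helper that maps a probability vector to the most likely
 * footwear label.
 */
final class TopLabel {
    // index order taken from the results: 0 = boots, 1 = sandals, 2 = shoes, 3 = slippers
    static final List<String> LABELS = Arrays.asList("boots", "sandals", "shoes", "slippers");

    private TopLabel() {
    }

    /** Returns the index of the largest probability. */
    static int argmax(double[] probabilities) {
        if (probabilities.length != LABELS.size()) {
            throw new IllegalArgumentException("expected one probability per label");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Formats the winning label and its probability as a percentage. */
    static String describe(double[] probabilities) {
        int best = argmax(probabilities);
        return String.format(Locale.ROOT, "%s (%.2f%%)", LABELS.get(best), probabilities[best] * 100.0);
    }
}
```

For the second log entry, TopLabel.describe(new double[] {0.02111, 0.76524, 0.01159, 0.20204}) returns "sandals (76.52%)".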

DJL provides a native Java development experience and functions as any other Java library would. The APIs are designed to guide developers with best practices to accomplish deep learning tasks. Before starting with DJL, a good understanding of the ML lifecycle is needed. If you’re new to ML, read an overview or start with InfoQ’s article series, an introduction to machine learning for software developers. After understanding the lifecycle and common ML terms, developers can quickly come up to speed on DJL’s APIs.

Amazon has open-sourced DJL; more detailed information about the toolkit is available on the DJL website and the Java Library API Specification page. The code for the footwear classification model is also available for readers who want to explore the examples further.

About the Author

Kesha Williams is an award-winning software engineer, machine learning practitioner, and technical instructor at A Cloud Guru with 24 years’ experience. She's trained and mentored thousands of Java software engineers in the US, Europe, and Asia while teaching at the university level. She routinely leads innovation teams in proving out emerging technologies and shares her learnings at conferences across the globe. She's spoken about machine learning on the TED stage as a winner of TED’s Spotlight Presentation Academy. Additionally, her pioneering work in the field of artificial intelligence earned her the distinction of both Alexa Champion and AWS Machine Learning Hero from Amazon. In her spare time, she mentors women in tech through her online social & professional networking platform, Colors of STEM.
