Key Takeaways
- Developers can build, train, and deploy machine learning (ML) and deep learning (DL) models using Java and their favorite IDE
- DJL simplifies the use of deep learning (DL) frameworks and currently supports Apache MXNet
- The open-source nature of DJL should be mutually beneficial for the toolkit and its users
- DJL is engine agnostic, which means developers can write code once and run it on any engine
- Java developers should have an understanding of the ML lifecycle and common ML terms before attempting to use DJL
Amazon’s DJL is a deep learning toolkit used to develop machine learning (ML) and deep learning (DL) models natively in Java while simplifying the use of deep learning frameworks. Open-sourced just in time for re:Invent 2019, DJL provides a set of high-level APIs to train, test, and run inference. Java developers can develop their own models or, from their Java code, utilize pre-trained models developed by data scientists in Python.
DJL stays true to Java’s motto, "write once, run anywhere" (WORA), by being engine and deep learning framework-agnostic: developers can write code once and run it on any engine. DJL currently provides an implementation for Apache MXNet, an ML engine that eases the development of deep neural networks. DJL APIs use JNA (Java Native Access) to call the corresponding Apache MXNet operations. DJL also orchestrates infrastructure management, providing automatic CPU/GPU detection based on the hardware configuration to ensure good performance.
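As a quick, hedged illustration of that engine abstraction, application code can stay engine-neutral and ask DJL at runtime which engine and hardware it discovered (a minimal sketch using the ai.djl.engine.Engine API):
// a minimal sketch: DJL discovers whichever engine is on the classpath
// (e.g., MXNet) and the available hardware at runtime
Engine engine = Engine.getInstance();
System.out.println("Engine: " + engine.getEngineName());
System.out.println("GPUs detected: " + engine.getGpuCount()); // automatic GPU detection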
DJL APIs abstract commonly used functions to develop models, enabling Java developers to leverage existing knowledge to ease the transition to ML. To see DJL in action, let’s walk through developing a footwear classification model as a simple example.
The Machine Learning Lifecycle
The footwear classification model is produced by following the machine learning lifecycle. The ML lifecycle is different from the traditional software development lifecycle and consists of six concrete steps:
- Obtain the data
- Clean and prepare the data
- Generate the model
- Evaluate the model
- Deploy the model
- Obtain a prediction (or inference) from the model
The end result of the lifecycle is a machine learning model that can be consulted and return an answer (or prediction).
A model is simply a mathematical representation of trends and patterns found in data. Good data are the foundation for all ML projects.
In step 1, data are obtained from a reputable source. In step 2, the data are cleaned, transformed, and put in a format that a machine can learn from. The cleaning and transformation process is often the most time-intensive piece of the machine learning lifecycle. DJL eases this process for developers by providing the ability to preprocess images using translators. Translators can handle tasks such as resizing images to expected dimensions or converting them from color to grayscale.
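As a small, hedged illustration, preprocessing steps like these can be expressed as DJL transforms chained into a Pipeline (Resize and ToTensor live in ai.djl.modality.cv.transform; the 224x224 size is an assumption, not the article’s setting):
// a minimal sketch: chain image preprocessing steps in a DJL Pipeline
Pipeline pipeline = new Pipeline()
        .add(new Resize(224, 224)) // scale each image to the network's expected input size
        .add(new ToTensor());      // convert the image to a CHW float tensor in [0, 1]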
Developers transitioning to machine learning often underestimate the time needed to clean and transform data, so translators are a great way to jumpstart the process. During the training process, step 3, a machine learning algorithm makes multiple passes (or epochs) over the data, studying them and trying to learn the different types of footwear. The trends and patterns found, as they relate to footwear, are stored in the model. Step 4 occurs as a part of training when the model is evaluated to determine how accurately it identifies footwear; if mistakes are uncovered, they are corrected. In step 5, the model is deployed to a production environment. Once the model is in production, step 6 allows it to be consumed by other systems.
Typically, models can be dynamically loaded in your code or accessed via a REST-based HTTPS endpoint.
Data
The footwear classification model is a multiclass classification computer vision (CV) model, trained using supervised learning, that classifies footwear into one of four class labels: boots, sandals, shoes, or slippers. Supervised learning requires data that are already labeled with the target (or answer) you’re trying to predict; this is how the machine learns.
The data source for the footwear classification model is the UT Zappos50K dataset, provided by The University of Texas at Austin and freely available for academic, non-commercial use. The dataset consists of 50,025 labeled catalog images collected from Zappos.com.
The footwear data were saved locally and loaded using DJL’s ImageFolder dataset, which retrieves images from a local folder.
// identify the location of the training data
String trainingDatasetRoot = "src/test/resources/imagefolder/train";
// identify the location of the validation data
String validateDatasetRoot = "src/test/resources/imagefolder/validate";
// create the training data ImageFolder dataset
ImageFolder trainingDataset = initDataset(trainingDatasetRoot);
// create the validation data ImageFolder dataset
ImageFolder validateDataset = initDataset(validateDatasetRoot);
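The initDataset(...) helper isn’t shown above; the following is a hypothetical sketch of what it might look like using DJL’s ImageFolder builder. The builder method names follow a recent DJL release and may differ from the version used in the article, and BATCH_SIZE, NEW_WIDTH, and NEW_HEIGHT are assumed constants:
// a hypothetical sketch of initDataset: build an ImageFolder dataset
// from a local folder, where subfolder names become the class labels
private ImageFolder initDataset(String datasetRoot) throws IOException, TranslateException {
    ImageFolder dataset = ImageFolder.builder()
            // point the dataset at the local folder of labeled images
            .setRepositoryPath(Paths.get(datasetRoot))
            // preprocess each image: resize it, then convert it to a tensor
            .addTransform(new Resize(NEW_WIDTH, NEW_HEIGHT))
            .addTransform(new ToTensor())
            // shuffle the data and serve it in batches
            .setSampling(BATCH_SIZE, true)
            .build();
    // scan the folder and index the images and labels
    dataset.prepare();
    return dataset;
}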
When structuring the data locally, I didn’t go down to the most granular level identified by the UT Zappos50K dataset, such as the ankle, knee-high, mid-calf, and over-the-knee labels for boots. My local data are kept at the highest level of classification, which includes only boots, sandals, shoes, and slippers.
In DJL terms, a dataset simply holds the training data. There are dataset implementations that can be used to download data (based on the URL you provide), extract data, and automatically separate data into training and validation sets.
The automatic separation is a useful feature as it is important to never use the same data the model was trained with to validate the model’s performance. The training dataset is used by the model to find trends and patterns in footwear data. The validation dataset is used to qualify the model’s performance by providing an unbiased estimate of the model’s accuracy at classifying footwear.
If the model were validated using the same data it was trained with, our confidence in the model’s ability to classify shoes would be much lower because the model is being tested with data it has already seen. In the real world, a teacher would not test a student using the exact same questions provided on a study guide because that would not measure the student’s true knowledge or understanding of the material; the same concept holds true for machine learning models.
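To illustrate that automatic separation, recent DJL releases expose a randomSplit method on RandomAccessDataset. A minimal, hedged sketch (the 80/20 ratio is an assumption, not the article’s setting):
// a minimal sketch: given a prepared RandomAccessDataset named dataset,
// randomly split it 80/20 into training and validation subsets
RandomAccessDataset[] subsets = dataset.randomSplit(8, 2);
RandomAccessDataset trainingSet = subsets[0];
RandomAccessDataset validationSet = subsets[1];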
Training
Now that we have the footwear data separated into training and validation sets, let’s use a neural network to train (or produce) the model.
public final class Training extends AbstractTraining {

    . . .

    @Override
    protected void train(Arguments arguments) throws IOException {
        // identify the location of the training data
        String trainingDatasetRoot = "src/test/resources/imagefolder/train";
        // identify the location of the validation data
        String validateDatasetRoot = "src/test/resources/imagefolder/validate";
        // create the training data ImageFolder dataset
        ImageFolder trainingDataset = initDataset(trainingDatasetRoot);
        // create the validation data ImageFolder dataset
        ImageFolder validateDataset = initDataset(validateDatasetRoot);
        . . .
        try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
            TrainingConfig config = setupTrainingConfig(loss);
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(metrics);
                trainer.setTrainingListener(this);
                Shape inputShape = new Shape(1, 3, NEW_HEIGHT, NEW_WIDTH);
                // initialize the trainer with the proper input shape
                trainer.initialize(inputShape);
                // find the patterns in the data
                fit(trainer, trainingDataset, validateDataset, "build/logs/training");
                // set the model's properties
                model.setProperty("Epoch", String.valueOf(EPOCHS));
                model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));
                // save the model after training for inference later;
                // the model is saved as shoeclassifier-0000.params
                model.save(Paths.get(modelParamsPath), modelParamsName);
            }
        }
    }
}
The first step is to get a model instance by calling Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH). Deep learning, a form of machine learning, uses a neural network to train the model. A neural network is modeled after neurons in the human brain; neurons are cells that transmit information (or data) to other cells.
ResNet-50 is a neural network often used for image classification; the 50 indicates there are 50 layers of learning between the original input data and the final prediction. The getModel() method creates an empty model, constructs a ResNet-50 neural network, and sets the neural network on the model.
public class Models {

    public static ai.djl.Model getModel(int numOfOutput, int height, int width) {
        // create a new instance of an empty model
        ai.djl.Model model = ai.djl.Model.newInstance();
        // a Block is a composable unit that forms a neural network;
        // combine them like Lego blocks to form a complex network
        Block resNet50 =
                // construct the network
                new ResNetV1.Builder()
                        .setImageShape(new Shape(3, height, width))
                        .setNumLayers(50)
                        .setOutSize(numOfOutput)
                        .build();
        // set the neural network on the model
        model.setBlock(resNet50);
        return model;
    }
}
The next step is to set up and configure a Trainer by calling the model.newTrainer(config) method. The config object was initialized by calling the setupTrainingConfig(loss) method, which sets the training configuration (or hyperparameters) that determines how the network is trained.
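The setupTrainingConfig(loss) helper is not shown in the article; here is a hypothetical sketch using DJL’s DefaultTrainingConfig. These builder methods come from a more recent DJL release than the API shown in the snippets above, so treat the names as illustrative:
// a hypothetical sketch of setupTrainingConfig: bundle the loss function,
// an evaluation metric, and progress logging into one training configuration
private static TrainingConfig setupTrainingConfig(Loss loss) {
    return new DefaultTrainingConfig(loss)
            .addEvaluator(new Accuracy())                               // track classification accuracy
            .addTrainingListeners(TrainingListener.Defaults.logging()); // log training progress
}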
The next steps allow us to add features to the Trainer by setting:
- metrics, using trainer.setMetrics(metrics)
- a training listener, using trainer.setTrainingListener(this)
- the proper input shape, using trainer.initialize(inputShape)
Metrics collect and report key performance indicators (KPIs) during training that can be used to analyze and monitor training performance and stability. The next step is to kick off the training process by calling the fit(trainer, trainingDataset, validateDataset, "build/logs/training") method, which iterates over the training data and stores the patterns found in the model. At the end of training, a well-performing, validated model artifact is saved locally, along with its properties, using the model.save(Paths.get(modelParamsPath), modelParamsName) method.
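The body of fit(...) is elided here; the following is a minimal sketch of the epoch loop it performs. Recent DJL releases package this same loop as EasyTrain.fit(trainer, EPOCHS, trainingDataset, validateDataset), and EPOCHS is an assumed constant:
// a minimal sketch of the training loop inside fit(...):
// make EPOCHS passes over the data, learning from one batch at a time
for (int epoch = 0; epoch < EPOCHS; epoch++) {
    for (Batch batch : trainer.iterateDataset(trainingDataset)) {
        EasyTrain.trainBatch(trainer, batch); // forward pass, loss, backward pass
        trainer.step();                       // apply the weight updates
        batch.close();                        // release the batch's native memory
    }
}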
The metrics reported during the training process show that with each epoch (or pass) the accuracy of the model improves; the final training accuracy for epoch 9 is 90%.
Inference
Now that we’ve generated the model, it can be used to perform inference (or prediction) on new data for which we do not know the classification (or target).
private Classifications predict() throws IOException, ModelException, TranslateException {
    // the location of the model saved during training
    String modelParamsPath = "build/logs";
    // the name of the model set during training
    String modelParamsName = "shoeclassifier";
    // the path of the image to classify
    String imageFilePath = "src/test/resources/slippers.jpg";
    // load the image file from the path
    BufferedImage img = BufferedImageUtils.fromFile(Paths.get(imageFilePath));
    // holds the probability score per label
    Classifications predictResult;
    try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
        // load the model
        model.load(Paths.get(modelParamsPath), modelParamsName);
        // define a translator for pre- and post-processing
        Translator<BufferedImage, Classifications> translator = new MyTranslator();
        // run the inference using a Predictor
        try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator)) {
            predictResult = predictor.predict(img);
        }
    }
    return predictResult;
}
After setting the necessary paths to the model and the image to be classified, obtain an empty model instance using the Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) method and initialize it using the model.load(Paths.get(modelParamsPath), modelParamsName) method. This loads the model that was trained in the previous step.
Next, initialize a Predictor, with a specified Translator, using the model.newPredictor(translator) method. In DJL terms, a Translator provides model pre-processing and post-processing functionality; for example, with CV models, images may need to be resized or converted to grayscale before inference, and the raw network output needs to be mapped back to class labels. The Predictor allows us to perform inference on the loaded Model using the predictor.predict(img) method, passing in the image to classify.
This example shows a single prediction, but DJL also supports batch predictions. The inference is stored in predictResult, which contains the probability estimate per label.
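MyTranslator itself is not shown in the article; the following is a hypothetical sketch of what a minimal implementation might look like. The class labels, the NEW_WIDTH and NEW_HEIGHT constants, and the utility calls (BufferedImageUtils, NDImageUtils) are assumptions, and exact signatures vary across DJL versions:
// a hypothetical sketch of MyTranslator: preprocess the image into a tensor
// and postprocess the network's output into per-label probabilities
public class MyTranslator implements Translator<BufferedImage, Classifications> {

    private final List<String> classes = Arrays.asList("boots", "sandals", "shoes", "slippers");

    @Override
    public NDList processInput(TranslatorContext ctx, BufferedImage input) {
        // pre-processing: convert the image to an NDArray, resize it to the
        // network's expected input, and rescale pixels into a [0, 1] tensor
        NDArray array = BufferedImageUtils.toNDArray(ctx.getNDManager(), input);
        array = NDImageUtils.resize(array, NEW_WIDTH, NEW_HEIGHT);
        return new NDList(NDImageUtils.toTensor(array));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // post-processing: turn raw network outputs into probabilities
        // and pair each probability with its class label
        NDArray probabilities = list.singletonOrThrow().softmax(0);
        return new Classifications(classes, probabilities);
    }

    @Override
    public Batchifier getBatchifier() {
        // stack single inputs into a batch of one
        return Batchifier.STACK;
    }
}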
The inferences (per image) are shown below with their corresponding probability scores.
[Table: each test image shown alongside the [INFO] log output listing its probability score per class label]
DJL provides a native Java development experience and functions as any other Java library would. The APIs are designed to guide developers with best practices to accomplish deep learning tasks. Before starting with DJL, a good understanding of the ML lifecycle is needed. If you’re new to ML, read an overview or start with InfoQ’s article series, an introduction to machine learning for software developers. After understanding the lifecycle and common ML terms, developers can quickly come up to speed on DJL’s APIs.
Amazon has open-sourced DJL; further detailed information about the toolkit can be found on the DJL website and the Java Library API Specification page. The code for the footwear classification model can be reviewed to explore the examples further.
About the Author
Kesha Williams is an award-winning software engineer, machine learning practitioner, and technical instructor at A Cloud Guru with 24 years’ experience. She's trained and mentored thousands of Java software engineers in the US, Europe, and Asia while teaching at the university level. She routinely leads innovation teams in proving out emerging technologies and shares her learnings at conferences across the globe. She's spoken about machine learning on the TED stage as a winner of TED’s Spotlight Presentation Academy. Additionally, her pioneering work in the field of artificial intelligence earned her the distinction of both Alexa Champion and AWS Machine Learning Hero from Amazon. In her spare time, she mentors women in tech through her online social & professional networking platform, Colors of STEM.