ML in the Browser: Interactive Experiences with Tensorflow.js



Victor Dibia provides a friendly introduction to machine learning, covers concrete steps on how front-end developers can create their own ML models and deploy them as part of web applications. He discusses his experience building Handtrack.js - a library for prototyping real time hand tracking interactions in the browser.


Victor Dibia is a Research Engineer with Cloudera’s Fast Forward Labs. Prior to this, he was a Research Staff Member at the IBM TJ Watson Research Center, New York. His research interests are at the intersection of human computer interaction, computational social science, and applied AI.

About the conference

Software is changing the world. QCon empowers software development by facilitating the spread of knowledge and innovation in the developer community. A practitioner-driven conference, QCon is designed for technical team leads, architects, engineering directors, and project managers who influence innovation in their teams.


Dibia: This is a topic that I'm really excited about, taking something I really love, machine learning, and also blending that with my love for front-end and JavaScript development. Today, I just want to tell you all about how you can also start to do that. My name is Victor Dibia, I'm a Research Engineer with Cloudera Fast Forward Labs. With respect to my background, some of my work has cut across computer science and software engineering. I also have some formal training in human-computer interaction. Over the last four or five years, I've spent a lot of my time on applied machine learning.

A few years ago, I was a Research Scientist at IBM's TJ Watson Research Lab at Yorktown. I spent a lot of my time there prototyping natural language and natural gesture interaction interfaces. More recently, I tried to spend some time contributing to the community a bit in machine learning and in JavaScript. I serve as the Developer Expert for machine learning within the Google Developer Expert program. Currently, I'm a Research Engineer at Cloudera Fast Forward Labs. That's a little bit about me. Just to also give you a little bit of a background about what we do at Fast Forward Labs, we like to think of ourselves as the bridge between academia and industry. As part of that process, every quarter, we would release a research report. It would go through this really comprehensive process of identifying topics from academia, that sweet spot of making the transition from academic research into stuff that our clients can actually use.

What we'll do is that we'll spend a couple of months researching these topics and then we'll write reports around them. In addition to that, we also build interactive prototypes that show a concrete demonstration of what these technologies could do. The slide you see right now is a prototype we built, we released it just a couple of months ago, something called ConvNet Playground. Essentially, if you think of the task of semantic search, given an image, show me all of the images that are most similar to this, it's a common problem. It turns out that you can solve it really efficiently using pre-trained models, but there's still a lot of decisions that the data scientist needs to make. What model do I use? Inception-v3, VGG16. What does all of that mean? Even when you select a model, do you use the entire model, or do you just take a small chunk of that model?

What this prototype does is that it takes results from computations done in about eight different models, using about eight different models, sub-models constructed from each model. It shows you what the result of search performance would look like if you used any of this. This is just a really simple example of what that interface looks like. You select an image, you perform your search. You can see metrics about search score. Not just that, you can select different models, and then you can look at graphs that help you compare all of this. If you are interested in this sort of thing, we have a platform called These are all open-source projects or open-source interfaces that you can actually play with. For the rest of our work, we have a subscription service that anyone can also join.

That's a bit of background on what I do. To today's agenda, which is the exciting part, here's what I'm going to go over. The first thing, I want to give an introduction to ML in the browser. I'll start with some terminology, just to make sure we are all on the same page. Then I'll give an overview of the TensorFlow.js library. It's got two APIs - the low-level API for linear algebra, and also the high-level Layers API that you'd actually want to use to build complex neural networks. Then finally, I'll look at Handtrack.js. Handtrack.js is a JavaScript library, which I developed, for prototyping gesture recognition completely in the browser. This is an open-source project. I'll go through how I actually built this thing and made it available.

Why Machine Learning in the Browser?

Why machine learning in the browser? Before we jump into that, let's review some really simple terminology. The first speaker covered a bit of this, but for those who weren't there at that talk, I just want to cover the basics. When we talk about artificial intelligence, it's a super broad field. It's really old, dates as far back as the 1950s. You can think of it as just efforts directed towards making machines intelligent, and there are many ways you can do that. A sub-aspect of that is machine learning, which is algorithms that enable machines to independently learn from data. This idea is pretty powerful because there are a lot of problem spaces where, as software engineers, it's almost impossible to write concrete rules.

If you take the classic example of how do we differentiate a cat from a dog, you want to do things like, how many legs does it have? It has four legs. It's either a cat or a dog. If it has pointy ears, it's a cat. That can really get complex if you have things like multiple orientations, occlusions, and all that sort of thing. You want an algorithm where, rather than writing rules, you show it a bunch of examples and it independently learns the right patterns. Examples of these algorithms are things like decision trees, random forests, reinforcement learning, and neural networks, which brings us to the interesting part, deep learning. Deep learning is just a special case of machine learning. Here the main algorithms we use are neural networks. To define neural networks, I like to think of them as just a stack of computational units, where computation flows from one unit to another. I'll give an example.

Over the last few years, deep learning has become wildly popular. There are some reasons for that. One of them is that it's one of the only algorithms where your accuracy scales with data. The more data you have, the more likely your model is going to get more accurate, which is a really interesting concept. In addition, over the last 10 years, there's been, like, a lot of really great advances in neural network algorithms. You might have heard of things like ResNet, Inception. That's for image processing. You might have heard of things like BERT for language processing. All of these things have made neural networks really popular. The other interesting thing about this space is that these algorithms almost always have the best performance for any given task.

The next interesting thing that has made this development really interesting is GPUs. Researchers and industry practitioners have found ways to make neural networks and deep learning algorithms run really fast on GPUs. For the rest of the talk, whenever I mention AI and machine learning, what I'm actually talking about is deep learning and the supervised form of deep learning.

To cover a few more terms that I'll use as the talk progresses, I just want to introduce things like inputs, weights, sum and bias, activation functions, and output. The diagram you see here is an example of the simplest neural network you'll ever see, typically called the perceptron. Everything you see here is really easy to understand. The first things you have here are your inputs, whatever it is that flows into the network. The next thing you have is that this input is multiplied by some values called weights. All of that is summed and a constant value called the bias is added. Finally, you pass all of that through an activation function and then you get your output. The beautiful thing here is that if we select these weights very carefully, and we select the bias very carefully, and we select the right activation function, we can get this set of computation to represent or approximate functions that we are interested in. It turns out that this really simple neural network you have here could approximate things like an AND function, an OR function, or an XOR function. The next question in your mind is, how do we select these magical values, the weights, the bias, and the activation function?
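To make the terms concrete, here is a minimal sketch of a single perceptron in plain JavaScript. The step activation and the weight and bias values are hand-picked assumptions for illustration, not values from the talk or from a trained model. Only AND and OR are shown, because a single unit of this form cannot separate XOR on its own; that case needs stacked units.

```javascript
// Step activation: outputs 1 when the value is non-negative, otherwise 0.
const step = (value) => (value >= 0 ? 1 : 0);

// One perceptron: multiply inputs by weights, sum them, add the bias,
// then pass the result through the activation function.
function perceptron(inputs, weights, bias, activation = step) {
  const weightedSum = inputs.reduce((sum, x, i) => sum + x * weights[i], 0);
  return activation(weightedSum + bias);
}

// Hand-picked (not trained) parameters, chosen only for illustration.
const andGate = (a, b) => perceptron([a, b], [1, 1], -1.5);
const orGate = (a, b) => perceptron([a, b], [1, 1], -0.5);

console.log(andGate(1, 1), andGate(1, 0)); // 1 0
console.log(orGate(0, 1), orGate(0, 0));   // 1 0
```

Changing only the bias turns the same unit from AND into OR, which is the sense in which choosing weights and bias selects the function being represented.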

For the weights and the bias, you identify these through the training process. I don't have enough time to go into much detail, but hopefully you get a bit of the idea. This is a very simple version; however, it's probably too simple to solve most of the problems we want to actually solve in real life. It turns out that one very fast way to scale this up is to stack a lot more of these simple neural networks together. Here, the terminology that you should think of is, you have your regular input, and any set of computational units that sit between the input and the output are referred to as hidden layers. Any set of layers where all units in one layer are connected to all units in the next layer are called dense layers. That's just a set of introductory terminology.
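As an illustration of stacking, the sketch below wires two hidden units and one output unit into a tiny dense network that computes XOR. The function names, the step activation, and every weight and bias are hand-picked assumptions for this example; in practice these values would come from training.

```javascript
// Step activation: 1 when the value is non-negative, otherwise 0.
const step = (value) => (value >= 0 ? 1 : 0);

// Dense layer: every input feeds every unit.
// weights[j] holds the weights of unit j, biases[j] its bias.
function denseLayer(inputs, weights, biases, activation) {
  return weights.map((unitWeights, j) => {
    const sum = unitWeights.reduce((acc, w, i) => acc + w * inputs[i], biases[j]);
    return activation(sum);
  });
}

function xorNetwork(a, b) {
  // Hidden layer: unit 0 behaves like OR, unit 1 behaves like AND.
  const hidden = denseLayer([a, b], [[1, 1], [1, 1]], [-0.5, -1.5], step);
  // Output layer: fires when OR is on and AND is off.
  const [output] = denseLayer(hidden, [[1, -1]], [-0.5], step);
  return output;
}

console.log(xorNetwork(0, 0), xorNetwork(0, 1), xorNetwork(1, 0), xorNetwork(1, 1)); // 0 1 1 0
```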

If we think about the languages for machine learning, typically, we'll see languages like Python, Julia, and MATLAB. Occasionally, you'll see some really brave and amazing individuals who are actually doing machine learning in these languages: Java, Go, C, and C++. One thing you'll notice about these languages - they're back-end programming languages, they're typically high-performance languages, and they work really well because they support things like multi-threading, they have really fast IO access. However, is this a real picture of the landscape of languages that developers actually use?

To learn more about that, I went to the 2019 Stack Overflow survey of 90,000 developers. They were asked this question, "What programming language do you use the most?" Here we see that JavaScript is the clear winner, so 67% of people actually focus on JavaScript. Similarly, they asked people, "If there's a language you don't currently use, which of these languages are you most interested in? What do you want to actually learn?" Here again, we see a lot of interest in JavaScript. If we look at GitHub, the trend is pretty similar. For the last 10 years, JavaScript repositories have been the most popular repositories. There is this really interesting article from a couple of years ago, where someone literally said that, "JavaScript is eating the entire world." The idea is that on the back-end, people are writing JavaScript server applications using Node.js. On the front-end, JavaScript has just taken over.

At this point, we know that people love JavaScript. Those who don't use it, they want to use it. They use it on the back-end, they use it on the front-end. The idea is, can we take JavaScript and can we merge that with neural networks? Most of the time when I talk to people, the first thing they say is, "This is a bad idea." When you think about that, typically, they have a bunch of questions. The first question is, "Can I really train a neural network in JavaScript? Can I express ideas like LSTMs, dense layers, convolutional layers? Can I do all of that, but only in JavaScript?" It turns out that, yes, you can. Your path to that is the TensorFlow.js library. It's a library for building and training machine learning models completely in JavaScript.

The next question that people would ask is related to speed. The first thing is, JavaScript, while it's a relatively fast language, it wasn't really designed for high-performance computing. In addition to that, neural networks tend to be a really compute-intensive operation, and so people wonder, "Can I actually have fast computation, but all expressed in JavaScript?" Again, the answer is yes. How do you do that? You do that again through TensorFlow.js. On the browser, TensorFlow.js provides acceleration using WebGL. WebGL is this really nice standard that allows accelerated graphics computing in the browser. What it means here, what's relevant here is that it does have access to a GPU. If you do have a GPU or graphics card on that machine, WebGL would let you take advantage of that for accelerated computing. TensorFlow.js takes advantage of that. You get really fast computational inference or training as you want in the browser. On the back-end, you could use TensorFlow with Node.js. It gets its acceleration by binding directly to the TensorFlow C API. Pretty much similar performance you get from Python, you get that from Node.js on the back-end.

The final question is, how much effort is this? For many teams, it's a make or break decision. We have a machine learning team, and they spend a lot of time building and training models in TensorFlow for Python. If we want to use JavaScript, what is the effort associated with integrating that into our workflow? It turns out, not much effort. One of the interesting things is that the TensorFlow.js API is similar to the Keras API. Keras is just a framework for training neural network models. If you're familiar with Keras, it's an easy path or transition to actually start building models in TensorFlow.js. The other interesting thing is that if you have models that are already trained in TensorFlow Python, TensorFlow.js provides a converter that allows you to take your pre-trained models and then convert them into a web model format that can be loaded into your JavaScript application. I'll get to that in a little bit more detail as the talk progresses.

At this point, I'm thinking, we know that you can do fast computation. We know that the amount of effort to move to TensorFlow.js is not that much. Are there other benefits that you can get by actually using TensorFlow.js? Here are my top four reasons. The first has to do with privacy. Across the talks at the conference this year, you have probably heard people talk about data privacy over and over again. With running machine learning models in the browser, you get the opportunity to offer a really strong notion of privacy. Most of the time, you find out that despite the best intentions of many software companies, despite their efforts to protect data, data breaches still occur. However, with TensorFlow.js, you can offer a different kind of privacy pitch. You can say things like, "We don't see your data, and it never gets to our servers," as opposed to saying things like, "We see your data, but we promise to keep it safe."

For most software applications, in order to offer your user a service, you probably need to use their data. Many times, user data is sensitive and so how about we move the compute to the browser, perform all our computation, offer a service, and the data never gets to the server? The other interesting idea around machine learning in the browser is related to distribution. At this point, I always like to tell a really simple story. About three years ago, I had a few friends who wanted to get into machine learning. Their first step was to install TensorFlow, the TensorFlow software library in Python. Many of them just never got into machine learning because they could not install TensorFlow. It's a true story and that used to be a really complex process. Things have gotten a lot better since then: you can do pip install TensorFlow now, you can do conda install TensorFlow.

When it comes to deploying machine learning applications to end-users, however, it can still be a really complex process. If you do this in the browser, this simplifies that whole workflow. There are no installs, no drivers, no dependency issues. It's as simple as going to a URL and opening a web page, and everything just works. In addition, if you deploy your model hosted on NPM, you can actually have all the benefits of model hosting, versioning, and distribution that come with distributing libraries through NPM.

The next has to do with latency. These days models can be optimized to run really fast, and they can run really fast on mobile and in the browser. In some use cases, it's actually faster to run your model in the browser, rather than send a round-trip request to a server and then render the response back to the user. In addition, in resource-constrained regions like Africa and East Asia, you really cannot rely on the internet connectivity there. It's a much better user experience to download all of the model locally to the device and then offer a smooth user experience that doesn't depend on a constant internet connection.

Then, finally, the browser is designed for interactive experiences and machine learning can supercharge that. I'll show a few examples in a moment. With TensorFlow.js, you can build models on the fly. You can use rich user data available in the browser, such as the camera and other sensors where possible. You can retrain existing models and you can also enable really dynamic behavior. There are applications for these in ML education, retail, advertising, arts, entertainment, and gaming.

Before I proceed, I want to give a few concrete business use cases of how some industries are already using TensorFlow.js. An example is Airbnb, something related to privacy-preserving sensitive content detection. As part of the user onboarding process, they will ask the user to upload an image. They have observed that, in some cases, users might upload images that contain sensitive content, like their driver's license and other sorts of images. What they've done here is that they've put in a TensorFlow.js model that, right in the browser, will tell the user, "Your image contains sensitive content. We haven't seen it, but we can offer the service of telling you that you likely have sensitive content, and you probably should use another photograph."

Another interesting library here is nsfw.js. It's a JavaScript library, and it lets you check if an image contains indecent content, nudity, and all of that, but right there in the browser without sending that image or that content to any back-end server.

Another example I like is a slide from the recently concluded TensorFlow World Conference, where a team, ModiFace, had created a JavaScript application with a footprint of about 1.8 megabytes that does virtual try-ons in real-time. This is all integrated into a WeChat mini program. One thing you should notice is that it works really well. The hair color of the lady in the image is changed in real-time. It works as fast as 25 frames per second. This is a really fantastic example of optimization and performing machine learning in the browser in real-time.

Some of you might have seen a neural network playground which was released by the Google TensorFlow.js team, and just lets you interactively learn about how machine learning works with a few data sets and different problems.


Now that we have a pretty good idea of good examples of TensorFlow.js, when it works and why you might want to use it, let's dive a little bit more into how you actually get the stuff done. For that, we'll look through the TensorFlow.js API. What can you do with TensorFlow.js? The first thing you can do is author models. It means you can compose your models, define how many layers you want, define things like the type of blocks within your model. You can train this model, test it, and perform inference, all in JavaScript. You can do this both on the back-end in Node.js, and you could do this in the browser. I think of this as something called the online flow.

The next thing you can do with TensorFlow.js is that you can import pre-trained models for inference. You train the model in your favorite compute environment, GPU clusters, TPU clusters, you export your model, and then you import it into your JavaScript application. I like to think of this as the offline flow. You train offline, you import the model, and then you run inference. Then, finally, we have the hybrid flow, where you probably train offline. You import the model into your JavaScript application. Then, you can also use local data from your environment to fine-tune that model, and make it just a little better. That's the hybrid flow.

One thing that can be slightly confusing about TensorFlow.js is the fact that it runs in two different environments. The first is the browser front-end, which is a really resource-constrained environment. The second is Node.js, which is a complete server environment and you can do a lot more there. I'll talk a little bit about what these two environments look like. With the browser, you can import TensorFlow.js in two ways. The first is that you could import it using a script tag, just like you would do with, let's say, jQuery. You could also import it using NPM. You could do NPM install TensorFlow.js, and then include that in build frameworks and tools like React and Vue.js. Again, in the browser, acceleration is achieved using WebGL, and tensors are implemented as shader programs. That way, if you have a GPU available on that device, all of the computation can be performed on the GPU.

The other deployment environment is Node.js. You install TensorFlow.js using NPM install. If you want just the CPU version, you use NPM install tfjs-node. If you have a GPU, you do NPM install tfjs-node-gpu. The only caveat here is that right now, similar to the Python API, it will only take advantage of CUDA-enabled GPU cores. If your GPU is not an Nvidia or CUDA-enabled core, you probably can't use it. Then, for acceleration, it just provides bindings to the low-level TensorFlow C binary, so that way you get fairly similar performance to what you get from Python.

This is an overall picture of the TensorFlow.js ecosystem. At the bottom, we have Node.js and the browser, which are the two main deployment environments. On top of that, we have the two APIs that it supports, the low-level Ops API and the Layers API. For the rest of the talk, I'll walk through three examples. The first is how you can use the low-level Ops API to fit a polynomial function. If you have the polynomial function, how can we create a low-level optimizer to learn what that function looks like? The second thing I'll talk about is how we can build a two-layer autoencoder, which is a type of neural network, in the browser. While preparing for this talk, I spent some time building an interesting demo that shows all of these ideas, which I'm really excited to show you guys. Then finally, how can we take a model that's been trained offline, and then import that, and then use that in JavaScript?

Tensorflow.js Ops API

With regards to the low-level Ops API, my suggestion is that you use it with caution, and only use it if you really know what you're doing. Let's consider this use case. We have a polynomial function, f of x is equal to ax squared plus bx plus c. In this case, we want to learn the parameters a, b, and c. How are we going to go ahead and implement that in TensorFlow.js? The first thing we want to do is import the TensorFlow.js library. We also want to specify the variables. We're interested in learning the parameters and here we define them as tensors. We initialize them with the value of 0.1. Then we express the function: f of x is equal to ax squared plus bx plus c. This is what the API syntax looks like. The next thing we want to do is, given our data, we want to find a way to automatically compute a loss, which is the difference between what our model predicts and the actual ground truth label. We create a loss function, which is based on mean-squared error. We subtract the actual labels from the predictions, square the difference, and compute the mean. Then, we create an optimizer, which is just a function to give some signal about how to update our parameters as training progresses. Then, we run both of these through a given number of epochs. The good thing is that each time we call optimizer.minimize, it updates the parameters we're interested in learning. That's a very fast overview of the low-level Ops API.
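The slides with the TensorFlow.js code are not part of this transcript, so the following is a library-free sketch of the same loop in plain JavaScript. It swaps the library's optimizer for gradients of the mean-squared error written out by hand. The function name `fitPolynomial`, the learning rate, the epoch count, and the synthetic data (generated from a = 2, b = -1, c = 0.5) are all assumptions chosen for illustration.

```javascript
// Fit f(x) = a*x^2 + b*x + c by gradient descent on the mean-squared error.
function fitPolynomial(xs, ys, { learningRate = 0.1, epochs = 2000 } = {}) {
  // Parameters start at 0.1, as described in the talk.
  let a = 0.1;
  let b = 0.1;
  let c = 0.1;
  const n = xs.length;

  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradA = 0;
    let gradB = 0;
    let gradC = 0;
    for (let i = 0; i < n; i++) {
      const x = xs[i];
      const error = a * x * x + b * x + c - ys[i]; // prediction minus label
      // Derivatives of mean((prediction - label)^2) with respect to a, b, c.
      gradA += (2 / n) * error * x * x;
      gradB += (2 / n) * error * x;
      gradC += (2 / n) * error;
    }
    // One optimizer step: move each parameter against its gradient.
    a -= learningRate * gradA;
    b -= learningRate * gradB;
    c -= learningRate * gradC;
  }
  return { a, b, c };
}

// Hypothetical training data sampled from 2x^2 - x + 0.5 on [-1, 1].
const xs = Array.from({ length: 21 }, (_, i) => -1 + i / 10);
const ys = xs.map((x) => 2 * x * x - x + 0.5);

const fitted = fitPolynomial(xs, ys);
console.log(fitted); // values should land close to a = 2, b = -1, c = 0.5
```

In TensorFlow.js the inner gradient arithmetic is what the optimizer computes automatically on each call, which is why the talk's version only needs a prediction function, a loss function, and the training loop.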

Tensorflow.js Layers API

The next thing we're going to talk about really fast is the high-level Layers API. Two things – it's the recommended API for building neural networks and it's very similar in spirit to the Keras library. For those who haven't used Keras, Keras is this really well-designed API for composing and training neural networks. It gives a really nice way to think about the components in your network, and a really good way to express these components in the actual code.

To illustrate the API, I'd like us to walk through the process of building something called an autoencoder. An autoencoder is a neural network that has two parts. Typically, it's used for dimensionality reduction. A use case is, imagine that we have our input and it's 15 variables. We want to compress that into just two variables that represent all 15 input variables. There are two parts. The first part of the neural network is called an encoder. It takes the high-dimensional input, the 15 variables. Its goal is to output two values, which is called the bottleneck. The other requirement here is that it should learn a meaningful compression such that we can also extract or compute the original inputs just from that bottleneck representation. That's exactly what the decoder part of the neural network does: it takes this bottleneck that has been generated by the encoder, and it learns to reconstruct the original inputs from that bottleneck.

One interesting thing is that this whole model, this autoencoder, has been applied for the task of anomaly detection. The goal here is that if we have some normal data, we can learn this mapping from inputs to bottlenecks, small dimension, and then from small dimension to output. If we train this on a normal data set, that means every time we feed in some new data, and we perform a decoding, we get an output, then we should get about the same thing. How is this relevant to anomaly detection? If we train in normal data, we get an output that's similar to the input. Whenever we have an anomaly, a data set that this model has never seen, if it goes into the network and we try to reconstruct the output, we'll get something that's really different. That's called high reconstruction error. Depending on the size of this error of reconstruction, we can then flag it off as anomalies or not.
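A minimal sketch of that flagging step, assuming mean-squared error as the reconstruction error. The function names, the sample vectors, and the threshold are hypothetical; in a real application the reconstruction would come from the trained autoencoder and the threshold from inspecting errors on known normal data.

```javascript
// Reconstruction error: mean-squared difference between a sample
// and what the autoencoder produced for it.
function reconstructionError(original, reconstructed) {
  const total = original.reduce(
    (sum, value, i) => sum + (value - reconstructed[i]) ** 2,
    0
  );
  return total / original.length;
}

// Flag a sample when its reconstruction error exceeds the chosen threshold.
function isAnomaly(original, reconstructed, threshold) {
  return reconstructionError(original, reconstructed) > threshold;
}

// Hypothetical samples: the first reconstruction is close, the second is far off.
const threshold = 1;
console.log(isAnomaly([1, 2, 3], [1.5, 2, 3], threshold)); // false
console.log(isAnomaly([1, 2, 3], [3, 0, 5], threshold));   // true
```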

The idea is, if we have normal data, we have, let's say, an error of 0.8. If we have abnormal data, we have an error of 0.2 or vice versa. We can set some kind of threshold, where we'll say, when the reconstruction error is beyond this level, then we flag this data as an anomaly. How do we express all of this in JavaScript? If you recall, we have an input layer, which is 15 units. We have an input, we pass it to a dense layer that's 15 units. We pass that through another dense layer that's seven units, then we have a bottleneck layer, which is two units. That composes our encoder, and the other parts are the decoder.

To express this in code, we specify our inputs. On the left, if we're going to use TensorFlow Python in Keras, that's exactly how we'll do it: Input, with the shape set to the number of features. The interesting thing is that it's almost a one-to-one conversion if we do that in TensorFlow.js. To specify our next layer, we have a 15-unit dense layer. This is exactly how we do it in Python, on the left. This is how we'll do it in JavaScript, on the right. We can see it's almost a one-to-one mapping again. As opposed to just using Dense in Python, we have tf.layers.dense. We specify units and we specify the activation function.

Similarly, our next layer is a seven-unit dense layer, very similar between Python Keras and JavaScript. Then, finally, a bottleneck layer, which we refer to as Z here, we also add to our model. The other interesting thing is that in Python Keras, we can easily do something called a model.summary, and it prints out all of the configuration of a model. The exact same method is available in JavaScript and that way you can compare if what you have in JavaScript is really equivalent to what you've built, let's say, in Python and Keras. This is a very useful tool.
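Since the side-by-side slides are not reproduced in this transcript, here is a sketch of how the autoencoder described above could be written with the TensorFlow.js Layers API. The layer sizes (15, 7, 2) follow the talk; the activation functions, the mirrored decoder, the optimizer, the loss, and the name `buildAutoencoder` are assumptions. The TensorFlow.js module is passed in as the `tf` argument so the same function works with the script-tag global or with a module loaded in Node.js.

```javascript
// Sketch only: `tf` is the TensorFlow.js module supplied by the caller.
function buildAutoencoder(tf, numFeatures = 15) {
  const input = tf.input({ shape: [numFeatures] });

  // Encoder: 15 -> 7 -> 2. The two-unit layer is the bottleneck (z).
  const enc1 = tf.layers.dense({ units: 15, activation: 'relu' }).apply(input);
  const enc2 = tf.layers.dense({ units: 7, activation: 'relu' }).apply(enc1);
  const z = tf.layers.dense({ units: 2, activation: 'relu' }).apply(enc2);

  // Decoder (assumed to mirror the encoder): 2 -> 7 -> 15 -> numFeatures.
  const dec1 = tf.layers.dense({ units: 7, activation: 'relu' }).apply(z);
  const dec2 = tf.layers.dense({ units: 15, activation: 'relu' }).apply(dec1);
  const output = tf.layers
    .dense({ units: numFeatures, activation: 'sigmoid' })
    .apply(dec2);

  const model = tf.model({ inputs: input, outputs: output });
  // Assumed training setup: Adam optimizer with mean-squared error loss.
  model.compile({ optimizer: 'adam', loss: 'meanSquaredError' });
  return model;
}
```

With the library installed, usage might look like `const model = buildAutoencoder(require('@tensorflow/tfjs')); model.summary();`, which prints the layer configuration for comparison against the Keras version.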

As part of preparing for this talk, I built something called an anomaly detection model. The data set here is examples of heart rate ECG signals. The idea is that we have some normal signals and then we also have a bunch of abnormal signals. The normal signals are in blue and the abnormal signals are in green. Here we have some labels for our training data. We also have our test data where we don't have any labels. Then, the goal is that right in the browser, we can train a model, such that at test time, when we select a specific signal, it can tell us if this signal is actually normal or it's likely abnormal.

In addition to that, because the browser is super interactive, we can do things like plot histograms of the reconstruction error. What you see here is that abnormal data has a reconstruction error that's mostly between 1 and 1.04. Normal data has a reconstruction error that's between 0.95 and 1. If we put our threshold right at the 1 mark, then we can easily separate abnormal from normal. We can also visualize the parameters learned in the bottleneck layer. The goal is that we want to see that there are linearly separable differences between our normal data and our abnormal data.

With regards to speed, I did some benchmarking. How fast is inference in Python, in Node.js, and in the browser? How fast is training in Python, in Node.js, and in the browser? What we have here is that I trained on 2,600 data points for a total of 20 epochs. We can see that Keras Python has the best performance for training time. It's about 1.3 seconds. The next best performance comes from Node.js, which is JavaScript, but using the C bindings. Then, finally, it takes the longest time in the browser. This is not completely unexpected.

The surprising, interesting part was that when it was time for inference, where I ran inference on 500 data points, we see that the Keras Python back-end has a pretty OK runtime. TensorFlow.js Node takes even more time, but then the browser has the best performance. This is probably surprising to all of you, but there's a good explanation for this. I ran all of these benchmarks on my MacBook computer. It turns out that my MacBook computer actually has a GPU, a graphics card. However, it's not a CUDA-enabled graphics card. What that means is that back-end TensorFlow is unable to take advantage of that. In the browser, using WebGL, we're able to accelerate that computing, and that's why we have faster inference in the browser. Of course, if I had a CUDA-enabled machine, it would probably be much faster on the back-end.

Building Handtrack.js

This is the last section of this talk. Essentially, I'm going to walk you guys through Handtrack.js. Handtrack.js is pretty simple. Given an image, predict the location, the bounding box, of all hands that exist in this image. Underneath, it uses a convolutional neural network for object detection. All of this is bundled into a JavaScript library that allows any developer to just import the library and, right in the browser without any installation, start tracking the location of hands in an image. It's also open-source.

What does that look like? The only thing you need to do is import the library. It's hosted on jsdelivr. It turns out that if you have your model deployed on NPM, you can get access to that. jsdelivr, which is a content delivery network, will automatically serve your assets. Then you need to load the model. Once it's loaded, you can call the detect method and it will return a JSON object that contains the location of each hand in that image. This is what the results of that call look like. You get the bounding box x, y, the width and the height, the name of the class, which is hand, and also some confidence values [inaudible 00:36:44]: 0.5 confident there is a hand at this location, or 0.9 confident there is a hand at this location.

How did I build this? It's a typical machine learning process. We start with collecting some data, then we train the model. In this case, I used the TensorFlow object detection API. Then you convert that model into the web model format using the TensorFlow.js converter and bundle that into a web application. Starting with data assembly, I started with a dataset called EgoHands. This dataset is about 4,600 images of human hands from the egocentric viewpoint. Someone put on a Google Glass video device and recorded a bunch of examples of people performing different activities while looking down towards the hands.

All I did was write some code that converted the original labels, which were polygons, into bounding boxes, and then exported that into the TFRecord format required by TensorFlow for training models. The first step is to convert your images into the TFRecord format. For each image, there's some metadata, the list of bounding boxes and labels. I trained these using the TensorFlow object detection API. It's a collection of code that makes it a little easier to perform transfer learning for the task of object detection. It's an open-source project. I really recommend you look it up if you're interested in object detection problems.
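The polygon-to-bounding-box conversion described above comes down to taking the minimum and maximum of the polygon's coordinates. Here is a minimal sketch in JavaScript. It is not the original conversion script: the function name `polygonToBoundingBox`, the `[x, y]` point layout, and the `{ x, y, width, height }` output shape are all assumptions for illustration.

```javascript
// Hypothetical helper: convert a polygon (array of [x, y] points) into an
// axis-aligned bounding box. The { x, y, width, height } shape is assumed
// for illustration and is not taken from the original conversion script.
function polygonToBoundingBox(polygon) {
  if (polygon.length === 0) {
    throw new Error("polygon must contain at least one point");
  }
  let minX = Infinity;
  let minY = Infinity;
  let maxX = -Infinity;
  let maxY = -Infinity;
  for (const [x, y] of polygon) {
    minX = Math.min(minX, x);
    minY = Math.min(minY, y);
    maxX = Math.max(maxX, x);
    maxY = Math.max(maxY, y);
  }
  return { x: minX, y: minY, width: maxX - minX, height: maxY - minY };
}

// Made-up outline of a hand, in pixel coordinates.
const handOutline = [[120, 80], [180, 95], [170, 160], [110, 150]];
const box = polygonToBoundingBox(handOutline);
console.log(box); // { x: 110, y: 80, width: 70, height: 80 }
```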

Because this was going to run in the browser, I selected Mobilenet SSD, which is one of the fastest models available for object detection. I trained this using a cloud GPU cluster. Once I had all of this done, I exported the resulting model as a saved model. The next step was to find a way to get this model into TensorFlow.js. For that, I used the TensorFlow.js converter. To use that converter, all you need to do is install it in Python: pip install tensorflowjs. Then you can run the tensorflowjs_converter command from the command line. You need to supply things like the location of the model, the input format, and the output format.

The good thing here is that it supports different TensorFlow formats. If you have your model trained in Keras, you can import that. If you also want to use models from TensorFlow Hub, you could also do that. At the end of this process, the exported model is about 18.5 megabytes. In web terms, this is really huge, but in neural network terms, this is really small. These are the trade-offs that you probably want to make. One of the things that the converter does is that it shards your model into files with a maximum size of 4.2 megabytes, just to make it easier to build on the web. At the end of the process, I bundled all of this into an NPM library. It's hosted on NPM. The good part here is that for libraries you host on NPM, you can do things like versioning, and you can serve them over a CDN like jsdelivr. It's something that's definitely recommended.

Here I'll just show some examples of what people have built with Handtrack.js. This is an interactive exploration for looking through chemistry molecules. This is a library someone built called Jama.js, where they used Handtrack.js to prototype pinch-style interactions, but using the entire human hand. This is a really simple application I built, where you have a website open, and you can control a ping-pong game just by waving your hand in front of the camera. This is actually on CodePen. If you want to go there and pop it open, you can actually play this game.


Here's the part where I do the thing that you're not supposed to do - a live demo. By the way, you can always try it out, it's something that's out there. This is a simple React application. All that's happening here is, each time I click an image, the model looks at the content of the image, and it draws a bounding box around the location of hands. You could do things like vary the confidence threshold. What this just says is that each time there's a prediction, if the prediction is less than 0.5, if the confidence is less than 0.5, don't render that box. If you really move things down, it'll start to find hands where they're not. If you move it to the bottom, it finds hands where nothing exists. You can find that sweet spot where you can have the confident hands. You can do the same here. The thing you learn here is that machine learning models are probabilistic. There's always a confidence level attached. It's always a good idea to visually communicate that.
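The confidence slider in the demo amounts to filtering predictions by score before drawing them. The sketch below shows that idea in plain JavaScript. The helper name `filterByConfidence` is hypothetical, and the prediction shape (`bbox`, `class`, `score`) is an assumption based on the fields described earlier in the talk, not a guaranteed match for the library's output.

```javascript
// Hypothetical helper: keep only predictions whose confidence meets the
// threshold, so low-confidence boxes are never rendered.
// The { bbox: [x, y, width, height], class, score } shape is assumed here.
function filterByConfidence(predictions, scoreThreshold = 0.5) {
  return predictions.filter((p) => p.score >= scoreThreshold);
}

// Made-up predictions for illustration.
const predictions = [
  { bbox: [40, 60, 120, 150], class: "hand", score: 0.92 },
  { bbox: [300, 20, 80, 90], class: "hand", score: 0.31 },
];

const visible = filterByConfidence(predictions);
console.log(visible.length); // 1
```

Lowering `scoreThreshold` towards 0 lets through the spurious boxes seen in the demo; raising it hides real hands, so the value is best exposed to the user or tuned per application.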

Let's make it a little more complex. Can we spin up video and track hands? You can see that the model is looking at the feed from my camera. In real time, as I move my hands around, you can see it following and tracking that really well. A simple example is that you could connect this as a method of interaction. If I wave my hands, I can control this little game I have here. It all runs in real time at about 20 frames per second. You don't have to take my word for it. You could just go to that website and actually try it.

Another model or demo I want to show really fast, and this is something that I'll probably release within the next month or so, is a simple interactive experience where you could build an autoencoder. In this case, you can add a couple of additional layers, so actually here we have an encoder with 12 units, 10 units, 5 units, 3 units, and another 3 units. You can compile that model. All of this is happening in the browser. You can actually train it. If we click train, we can actually see that we are learning to reconstruct input data. We actually see that as optimization progresses, we can see the loss or the errors kind of go down.

The second interactive graph you have here is a graph of the mean squared error. What it says is that, as the model has begun to train, its reconstruction error for normal data is the histogram we see on the left, and its reconstruction error for abnormal data is the histogram we see on the right. Typically, if you set a threshold right at this point, we can very easily distinguish normal from abnormal data.
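For reference, the mean squared error between an input signal and its reconstruction can be computed as follows. This is a plain JavaScript sketch for illustration only; the function name `meanSquaredError` and the sample arrays are hypothetical, and the demo itself computes this with TensorFlow.js rather than with hand-written loops.

```javascript
// Hypothetical helper: mean squared error between an input signal and the
// autoencoder's reconstruction of it. Both arrays must have the same length.
function meanSquaredError(original, reconstructed) {
  if (original.length === 0 || original.length !== reconstructed.length) {
    throw new Error("inputs must be non-empty and of equal length");
  }
  let sum = 0;
  for (let i = 0; i < original.length; i++) {
    const diff = original[i] - reconstructed[i];
    sum += diff * diff;
  }
  return sum / original.length;
}

// Made-up values: only the last sample is reconstructed badly.
const inputSignal = [1, 2, 3, 4];
const reconstruction = [1, 2, 3, 6];
console.log(meanSquaredError(inputSignal, reconstruction)); // 1
```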

Once we have this model trained, we can go back to our data and say, for these data points I have here, tell me if it's normal or abnormal; for this data point I have here, tell me if it's normal or not normal. To give some additional background, the data I have here is from a dataset called ECG5000. It's an example of about 3,000 discretized recorded heart rate signals. They're all labeled as either normal or having some kind of defect or issue. You could train an autoencoder that learns to tell these things apart.

Some Challenges

So far, I've talked about all the beautiful things. TensorFlow.js is amazing, it works really well, but there are some challenges. The first challenge you need to think about is memory management. TensorFlow.js implements tensors as WebGL textures. The implication of this is that you need to explicitly manage your tensors. If you don't delete them explicitly, your application just keeps on hogging memory. The second thing is that the browser is limited: it's single-threaded, and compute-heavy operations block the UI. You need to think about this all the time. There's also device fragmentation. Some devices have GPUs, some don't. Sometimes, if your device goes into low power mode, it will disable the GPU, and so you can get wildly different performance.

Some good practices: Try to spend some time optimizing your model. It might mean that if you compose your model, you might want to stay away from some compute-heavy operations like full 2D convolutions and replace that, let's say, with depthwise convolutions. You want to remove post-processing from your model graph. You want to explore model compression, approaches like quantization, distillation, and pruning. You also want to look at AutoML or neural architecture search, because this can find really efficient, energy-efficient and fast architectures automatically.
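To make the quantization idea concrete, here is a minimal sketch of 8-bit linear quantization of a weight array in plain JavaScript. It illustrates the general technique only and is not how any particular tool implements it; the function names `quantizeToUint8` and `dequantize` and the sample weights are hypothetical.

```javascript
// Hypothetical helper: map float weights onto 256 evenly spaced levels
// between the smallest and largest weight, storing one byte per weight.
function quantizeToUint8(weights) {
  const min = weights.reduce((a, b) => Math.min(a, b), Infinity);
  const max = weights.reduce((a, b) => Math.max(a, b), -Infinity);
  // Avoid dividing by zero when all weights are identical.
  const scale = max === min ? 1 : (max - min) / 255;
  const quantized = Uint8Array.from(weights, (w) => Math.round((w - min) / scale));
  return { quantized, scale, min };
}

// Hypothetical helper: recover approximate float weights from the bytes.
function dequantize({ quantized, scale, min }) {
  return Array.from(quantized, (q) => q * scale + min);
}

// Made-up weights for illustration.
const weights = [-1.0, -0.5, 0.0, 0.5, 1.0];
const packed = quantizeToUint8(weights);
const restored = dequantize(packed);
console.log(packed.quantized.length, restored.length);
```

Each weight now takes one byte instead of four, at the cost of a rounding error of at most half a quantization step per weight.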

Other good practices: Learn to use asynchronous methods; that way you avoid blocking the UI thread. Visually communicate model load and inference latency. Wherever possible, communicate any uncertainty associated with predictions. The Chrome Profiler is your friend. That's one great way to see if your application is hogging up more and more memory. Then, for more examples on how to design human-centered AI products, I recommend the Google PAIR People + AI Guidebook.
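One generic way to keep the UI responsive is to split heavy work into chunks and yield to the event loop between them. The sketch below shows that pattern in plain JavaScript. It is not a TensorFlow.js API: the helper name `processInChunks`, its parameters, and the default chunk size are hypothetical.

```javascript
// Hypothetical helper: run handleItem over items in small chunks, yielding
// to the event loop after each chunk so queued UI events can be handled.
async function processInChunks(items, handleItem, chunkSize = 100) {
  const results = [];
  for (let start = 0; start < items.length; start += chunkSize) {
    for (const item of items.slice(start, start + chunkSize)) {
      results.push(handleItem(item));
    }
    // setTimeout(..., 0) schedules the continuation as a new task, which
    // gives the browser a chance to handle input and repaint in between.
    await new Promise((resolve) => setTimeout(resolve, 0));
  }
  return results;
}

// Usage: square five numbers, two at a time.
processInChunks([1, 2, 3, 4, 5], (x) => x * x, 2).then((squares) => {
  console.log(squares); // [ 1, 4, 9, 16, 25 ]
});
```

Smaller chunks keep the page more responsive but add scheduling overhead, so the chunk size is worth tuning against the Chrome Profiler.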


To conclude, here are the key takeaways. Machine learning in the browser is really interesting. It's attractive if your use case has to do with privacy, or if you want easy distribution, low latency, and interactivity. How do you get there? You get there using TensorFlow.js. It provides an API for ML in JavaScript, and it supports these three interesting flows: online, offline, and hybrid. In addition, it can be fast: there's acceleration in the browser, and also bindings to the low-level C++ API. It's expressive and it plays well with the rest of the TensorFlow ecosystem. Then, finally, research advances in compression, distillation, and pruning make models run really fast. In addition, there are some new standards that enable accelerated computation on the web. I think these two things come together to make this space extremely exciting.

Questions and Answers

Participant 1: Thanks for a great talk. While it's pretty clear that there are lots of benefits in the inference phase with TensorFlow.js, at the same time, when you think about this online training mode and using TensorFlow for training in the browser, do you feel like we're ready? What's your impression of the limitations or confines of a browser? Where are we with this right now?

Dibia: For really large models, for large datasets, you probably do not want to train all of that in the browser. From my experience, and from looking at the industry right now, the most popular flow is what I refer to as the offline flow. What most people would do is that they'll train their models using their GPU clusters or whatever hardware acceleration is available to them. Once they have it at an accuracy level that they're really comfortable with, they spend more time doing optimization, essentially, bringing down the footprint of that model, then finally convert that and perform inference in the browser. I tend to see more of that.

Then, the next best thing to that is you train the model, you quantize it, you compress it, you optimize it, you import into the browser. Then maybe in the browser, you could use user data to fine-tune it. That's not an extremely expensive operation, but those are the two ways in which people are actually solving this problem or getting this done. In the demo I showed you, you could actually train the autoencoder from scratch. What I shared right now was about 2,000 data points. Each data point is 140 variables. It's not an extremely large dataset, but it's possible. If you're looking at images, things become a lot more complicated and definitely not recommended for training.




Recorded at:

Jan 06, 2020