Overview
This is part two of a two-part blog series. If you haven’t read part 1, I recommend you do so first to get caught up to this point. All the code referenced can be found, in fully working order, here.
In part 1 we fine-tuned a ResNet18 model for classification of 7 labels and saved it as best.pt.
Specifically, in this post we will take that fine-tuned PyTorch model (best.pt), perform inference with it in Python, and export it for Node.js and web browser contexts using onnxruntime-node and onnxruntime-web.
Inference with Python
It’s worth testing this model in Python first, as it’s a quick way to qualitatively check it, or just get a fast result and feel good.
First, remember we are using these labels:
labels_map = {
    0: "bike",
    1: "car",
    2: "cat",
    3: "dog",
    4: "flower",
    5: "horse",
    6: "human",
}
Next, we will put some random images in a new folder data/inference to test with. We can write code to perform transforms on the data and load it like so:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

data_dir = "./data/inference"
# The image input size we need
input_size = 224
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
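One thing to note: datasets.ImageFolder (used in the next snippet) expects one subfolder per class under data/inference. The exact filenames don’t matter; the layout might look something like:
data/inference/
    bike/
        bike-001.jpg
    flower/
        flower-001.jpg
    human/
        rider-12.jpg
    ...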
After setting up the transforms, PyTorch has a cool utility, SubsetRandomSampler, which lets us randomly sample from a dataset.
def get_random_images(num):
    data = datasets.ImageFolder(data_dir, transform=test_transforms)
    indices = list(range(len(data)))
    np.random.shuffle(indices)
    idx = indices[:num]
    from torch.utils.data.sampler import SubsetRandomSampler
    sampler = SubsetRandomSampler(idx)
    loader = torch.utils.data.DataLoader(
        data,
        sampler=sampler,
        batch_size=num)
    dataiter = iter(loader)
    images, labels = next(dataiter)
    return images, labels
The actual inference, passing the image through the model, is done like so:
def predict_image(model, image):
    image_tensor = test_transforms(image).float()
    image_tensor = image_tensor.unsqueeze_(0)
    input = image_tensor
    input = input.to(device)
    output = model(input)
    index = output.data.cpu().numpy().argmax()
    print(list(output.data.cpu()[0]))
    print(index)
    return index
We will use matplotlib for the visualization piece, purely for our own qualitative benefit.
Finally, putting it together we can randomly sample images for inference from our folder like so:
model = torch.load('./best.pt')
model.eval()

to_pil = transforms.ToPILImage()
images, labels = get_random_images(5)
fig = plt.figure(figsize=(10, 10))
for ii in range(len(images)):
    image = to_pil(images[ii])
    index = predict_image(model, image)
    label = labels_map[index]
    sub = fig.add_subplot(1, len(images), ii + 1)
    sub.set_title("Actual: {0}{1}\nGuess: {2}{3}".format(
        labels_map[int(labels[ii])], int(labels[ii]), label, index))
    plt.axis('off')
    plt.imshow(image, cmap="gray")
plt.show()
If you set it up correctly you should see something like a row of sampled images, each titled with its actual label and the model’s guess.
Exporting the model to ONNX format
Great: you have a model, and you know how to use it for inference within Python. Now we want to export it to ONNX format so we can use it in other contexts, specifically Node.js and the browser.
We use torch.onnx.export to export the model to ONNX format, and onnxruntime to check that the exported model loads and runs. The code is simple, so it is all included below.
The notable piece here is specifying the input tensor size, 3x224x224 in this case. We also set the batch_size dimension to be dynamic so we can pass in any batch size we want; this matters for the web context, where we will be passing in a single image at a time.
import onnx
import torch
import onnxruntime as ort
import numpy as np

batch_size = 1
image_size = 224
model_name = "resnet18.onnx"

model = torch.load('./best.pt')
model.eval()

x = torch.randn(batch_size, 3, image_size, image_size, requires_grad=False)
torch_out = model(x)

# Export the model
torch.onnx.export(model,
                  x,
                  model_name,
                  export_params=True,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={
                      'input': {0: 'batch_size'},
                      'output': {0: 'batch_size'}
                  })

onnx_model = onnx.load(model_name)
onnx.checker.check_model(onnx_model)

ort_sess = ort.InferenceSession(model_name)
outputs = ort_sess.run(None, {'input': x.numpy()})
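As an optional sanity check (not strictly required), we can verify that the ONNX Runtime output closely matches the PyTorch output for the same random input:
# Optional: compare the ONNX Runtime output against the PyTorch output.
np.testing.assert_allclose(torch_out.detach().numpy(), outputs[0], rtol=1e-3, atol=1e-5)
print("Exported model output matches PyTorch output")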
Running this code produces the file resnet18.onnx, which we will use in the next section.
Inference in NodeJS
Before we get to the browser, I think it’s easier to see this working in a Node.js context first, because there is no HTML/CSS to deal with and we can just focus on the JavaScript code.
Here are the package.json dependencies we are working with for this section:
"dependencies": {
"blueimp-load-image": "^5.16.0",
"canvas": "^2.11.2",
"ndarray": "^1.0.19",
"ndarray-ops": "^1.2.2",
"onnxruntime-node": "^1.14.0",
"save": "^2.9.0"
}
We will use blueimp-load-image to load the image, canvas to resize it, and onnxruntime-node to run the model.
First, we will load the image and resize it to 224x224. We will also convert it to an ndarray, which makes it easy to rearrange the pixel data into the layout the model expects before handing it to onnxruntime-node.
First, the imports:
const ort = require("onnxruntime-node");
const ImageLoader = require('./ImageLoader.js');
const ndarray = require("ndarray");
const ops = require("ndarray-ops")
const fs = require("fs")
const Tensor = ort.Tensor;
Then the code to preprocess the image data and build the model’s input tensor:
function preprocess(data, width, height) {
    const dataFromImage = ndarray(new Float32Array(data), [width, height, 4]);
    const dataProcessed = ndarray(new Float32Array(width * height * 3), [1, 3, height, width]);

    // Normalize 0-255 to (-1)-1
    ops.divseq(dataFromImage, 128.0);
    ops.subseq(dataFromImage, 1.0);

    // Realign imageData from [224*224*4] to the correct dimension [1*3*224*224].
    ops.assign(dataProcessed.pick(0, 0, null, null), dataFromImage.pick(null, null, 2));
    ops.assign(dataProcessed.pick(0, 1, null, null), dataFromImage.pick(null, null, 1));
    ops.assign(dataProcessed.pick(0, 2, null, null), dataFromImage.pick(null, null, 0));

    return dataProcessed.data;
}

async function getInputs(imageSize = 224) {
    // Load image.
    const imageLoader = new ImageLoader(imageSize, imageSize);
    const imageData = await imageLoader.getImageData('./data/inference/human/rider-12.jpg');

    // Preprocess the image data to match input dimension requirement
    const width = imageSize;
    const height = imageSize;
    const preprocessedData = preprocess(imageData.data, width, height);

    const tensorB = new Tensor('float32',
        preprocessedData,
        [1, 3, imageSize, imageSize]
    );
    return tensorB;
}
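The ImageLoader helper referenced above isn’t shown here; it’s a small wrapper that loads an image, resizes it, and returns its raw RGBA pixel data. The repo’s version is built on blueimp-load-image and canvas; a minimal sketch using just the canvas package (my approximation, not the exact code from the repo) might look like:
// ImageLoader.js -- a minimal sketch using only the `canvas` package.
const { createCanvas, loadImage } = require("canvas");

class ImageLoader {
    constructor(width, height) {
        this.width = width;
        this.height = height;
    }

    async getImageData(path) {
        // Load the image from disk, draw it resized onto a canvas,
        // and return the ImageData (raw RGBA pixels, width * height * 4 bytes).
        const image = await loadImage(path);
        const canvas = createCanvas(this.width, this.height);
        const ctx = canvas.getContext("2d");
        ctx.drawImage(image, 0, 0, this.width, this.height);
        return ctx.getImageData(0, 0, this.width, this.height);
    }
}

module.exports = ImageLoader;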
Finally, we can run the model and get the output (the labelsMap array and indexOfMax helper are the same plain-JavaScript helpers shown in the browser example later in this post):
async function inference() {
    // load the ONNX model file
    const cpuSession = await ort.InferenceSession.create('./resnet18.onnx');

    // generate model input
    const inferenceInputs = await getInputs();

    // execute the model
    const output = await cpuSession.run({ "input": inferenceInputs });

    // prints the label of the class with the highest score
    console.log(labelsMap[indexOfMax(output.output.data)]);
}

inference();
Inference in the Web Browser
Here is the demo link for the web browser inference.
Now I’ll walk you through a minimal example with fewer bells and whistles.
First, we need to use browserify to bundle the code, because ndarray and ndarray-ops are node modules; bundling them makes them usable in the browser.
$ npm install --global browserify
$ npm install ndarray ndarray-ops
Next, create a file main.js:
var ndarray = require("ndarray");
var ops = require("ndarray-ops");
global.window.ops = ops;
global.window.ndarray = ndarray;
Then we can bundle the code:
$ browserify main.js -o bundle.js
Finally, in our barebones index.html we can load the bundle and run the inference:
Our <head> tag imports:
<head>
    <title>ONNX Runtime JavaScript Browser Example</title>
    <!-- import ONNXRuntime Web from CDN -->
    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
    <script src="./bundle.js"></script>
</head>
In our <body> we have the following <script> code:
Notice we use the same preprocess function from our Node.js example. Most of the code here is just loading the image in the format we want (from a hard-coded URL in this case). Then we run the model and get the output.
I won’t go over the image loading code as it’s outside the scope of this blog post, but I’m sure if you’ve followed along this far it won’t be too hard to figure out.
<script>
function getImageData(url, width, height) {
    const img = new Image();
    img.setAttribute('crossOrigin', 'anonymous');
    return new Promise((resolve, reject) => {
        img.onload = function () {
            // Draw the image resized to width x height and read back the raw RGBA pixels.
            const canvas = document.createElement("canvas");
            canvas.width = width;
            canvas.height = height;
            const ctx = canvas.getContext("2d");
            ctx.drawImage(this, 0, 0, width, height);
            resolve(ctx.getImageData(0, 0, width, height).data);
        };
        img.onerror = reject;
        img.src = url;
    });
}
const labelsMap = [
    "bike",
    "car",
    "cat",
    "dog",
    "flower",
    "horse",
    "human",
];
function indexOfMax(arr) {
    if (arr.length === 0) {
        return -1;
    }
    let max = arr[0];
    let maxIndex = 0;
    for (let i = 1; i < arr.length; i++) {
        if (arr[i] > max) {
            maxIndex = i;
            max = arr[i];
        }
    }
    return maxIndex;
}
function preprocess(data, width, height) {
    const dataFromImage = ndarray(new Float32Array(data), [width, height, 4]);
    const dataProcessed = ndarray(new Float32Array(width * height * 3), [1, 3, height, width]);

    // Normalize 0-255 to (-1)-1
    ops.divseq(dataFromImage, 128.0);
    ops.subseq(dataFromImage, 1.0);

    // Realign imageData from [224*224*4] to the correct dimension [1*3*224*224].
    ops.assign(dataProcessed.pick(0, 0, null, null), dataFromImage.pick(null, null, 2));
    ops.assign(dataProcessed.pick(0, 1, null, null), dataFromImage.pick(null, null, 1));
    ops.assign(dataProcessed.pick(0, 2, null, null), dataFromImage.pick(null, null, 0));

    return dataProcessed.data;
}
// use an async context to call onnxruntime functions.
async function main() {
    try {
        // load the image and read back its raw RGBA pixels, resized to 224x224
        const data = await getImageData('./data/data/flowers/0001.png', 224, 224);

        // create a new session and load the specific model.
        const session = await ort.InferenceSession.create('./resnet18.onnx');

        // prepare feeds. use model input names as keys.
        const feeds = { input: new ort.Tensor(new Float32Array(preprocess(data, 224, 224)), [1, 3, 224, 224]) };

        // feed inputs and run
        const results = await session.run(feeds);

        // read from results
        const dataC = results.output.data;
        document.write(`Label: ${labelsMap[indexOfMax(dataC)]}`);
    } catch (e) {
        document.write(`failed to inference ONNX model: ${e}.`);
    }
}
main();
</script>
This isn’t meant to be a pretty or compelling example, just barebones to get something working in the browser. To test it locally I did:
$ python3 -m http.server
Then I went to http://localhost:8000/. You should see the label of the image printed on the screen after a few seconds (no loading spinner, very minimal example 😂).
Thanks everyone. Feel free to contact me through the contact form on my website if you have any questions or comments.
My next blog post will deal with deploying ML models to a different context: a custom FPGA-based hardware accelerator. Stay tuned.