How to finetune a model then deploy it to the web: Part 2 - 'Inference with Python and JavaScript'

Published on May 05, 2023 by Dominik Kaukinen

Overview

This is part two of a two-part blog series. If you haven’t read part 1, I recommend you do so first to get caught up. All the code referenced can be found in full working order here.

In part 1 we finetuned a ResNet18 model for classification of 7 labels and saved it as best.pt.

Specifically, in this post we will take our finetuned PyTorch model best.pt, run inference with it in Python, export it to ONNX, and then run it in Node.js (with onnxruntime-node) and in the browser (with onnxruntime-web).

Inference with Python

It’s worth testing this model in Python first: it’s a quick way to eyeball its predictions, or just get a fast result and feel good.

First, remember we are using these labels:

labels_map = {
    0: "bike",
    1: "car",
    2: "cat",
    3: "dog",
    4: "flower",
    5: "horse",
    6: "human",
}

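Next, we will put some random images to test with in a new folder, data/inference, with one subfolder per class (for example data/inference/cat/), because torchvision’s ImageFolder takes each image’s label from the name of its subfolder. We can then write code to transform the data and load it like so: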

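import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

data_dir = "./data/inference"

# The image input size the model expects
input_size = 224

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resize, center-crop, convert to a [0, 1] tensor, then normalize
# each channel with the ImageNet mean/std
test_transforms = transforms.Compose(
    [
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])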
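For intuition, here is what ToTensor() followed by Normalize() does to a single pixel, written out in plain Python: each channel is scaled from 0-255 to [0, 1], then the channel mean is subtracted and the result divided by the channel standard deviation. This is only a sketch; normalize_pixel is a hypothetical helper, and it assumes part 1 trained with these same ImageNet statistics (the usual choice for an ImageNet-pretrained ResNet18). Any preprocessing written outside Python has to reproduce these numbers for the model to see the kind of input it was trained on.

```python
# Assumption: the same per-channel mean/std as the Normalize() call above.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Mimic ToTensor() + Normalize() for one 8-bit (R, G, B) pixel."""
    normalized = []
    for value, mean, std in zip(rgb, MEAN, STD):
        scaled = value / 255.0                    # ToTensor(): 0-255 -> [0.0, 1.0]
        normalized.append((scaled - mean) / std)  # Normalize(): (x - mean) / std
    return tuple(normalized)


if __name__ == "__main__":
    r, g, b = normalize_pixel((255, 128, 0))
    print(f"{r:.3f} {g:.3f} {b:.3f}")  # prints: 2.249 0.205 -1.804
```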

With the transforms set up, we can use PyTorch’s handy SubsetRandomSampler, which samples at random from a given list of dataset indices.

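    from torch.utils.data.sampler import SubsetRandomSampler

    def get_random_images(num):
        # Load images WITHOUT Normalize so they display correctly;
        # predict_image() applies the full test_transforms itself.
        display_transforms = transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
        ])
        data = datasets.ImageFolder(data_dir, transform=display_transforms)
        # Shuffle all indices and keep the first `num`
        indices = list(range(len(data)))
        np.random.shuffle(indices)
        sampler = SubsetRandomSampler(indices[:num])
        loader = torch.utils.data.DataLoader(data, sampler=sampler, batch_size=num)
        images, labels = next(iter(loader))
        return images, labels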
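ImageFolder numbers its classes by sorting the subfolder names alphabetically. The alphabetical order of our seven classes happens to be exactly the order of labels_map, which is why the "Actual" labels line up, but only if data/inference contains all seven class subfolders; with fewer, the indices shift. The stdlib sketch below (class_to_idx is a hypothetical helper) reproduces that numbering.

```python
import os
import tempfile


def class_to_idx(root):
    """Sketch of ImageFolder's labelling: one class per subdirectory,
    numbered in sorted (alphabetical) order. Plain files are ignored."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: i for i, name in enumerate(classes)}


if __name__ == "__main__":
    # Build a throwaway folder tree that mirrors data/inference
    with tempfile.TemporaryDirectory() as root:
        for name in ["human", "bike", "dog", "car", "horse", "cat", "flower"]:
            os.mkdir(os.path.join(root, name))
        # {'bike': 0, 'car': 1, 'cat': 2, 'dog': 3, 'flower': 4, 'horse': 5, 'human': 6}
        print(class_to_idx(root))
```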

The actual inference, passing the image through the model, is done like so:

def predict_image(model, image):
    # Preprocess the PIL image and add a batch dimension: [3, 224, 224] -> [1, 3, 224, 224]
    image_tensor = test_transforms(image).float().unsqueeze(0).to(device)
    # No gradients are needed for inference
    with torch.no_grad():
        output = model(image_tensor)
    scores = output.cpu()[0]
    index = int(scores.argmax())
    print(scores.tolist())  # raw score for each of the 7 classes
    print(index)
    return index
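predict_image prints the raw scores and picks the largest. Assuming the finetuned ResNet18 ends in a plain linear layer with no softmax (as torchvision’s ResNet does), those scores are unnormalized logits. A softmax turns them into probabilities without changing which class wins, which gives a rough confidence to show next to each guess. A minimal stdlib sketch; the scores are made up and top_prediction is a hypothetical helper:

```python
import math

# Same label order as labels_map above
LABELS = ["bike", "car", "cat", "dog", "flower", "horse", "human"]


def softmax(logits):
    """Convert raw model scores (logits) into probabilities that sum to 1."""
    largest = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits):
    """Return (label, probability) for the highest-scoring class."""
    probs = softmax(logits)
    index = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[index], probs[index]


if __name__ == "__main__":
    # Hypothetical scores, shaped like the list predict_image() prints
    scores = [-1.2, 0.3, 4.1, 2.0, -0.5, 0.1, -2.2]
    label, prob = top_prediction(scores)
    print(f"{label}: {prob:.1%}")
```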

We will use matplotlib to visualize the predictions so we can judge them qualitatively.

Finally, putting it together we can randomly sample images for inference from our folder like so:

    # map_location loads the weights onto the same device the inputs are sent to.
    # (On PyTorch 2.6+, also pass weights_only=False, since best.pt holds the whole model.)
    model = torch.load('./best.pt', map_location=device)
    model.eval()

    to_pil = transforms.ToPILImage()
    images, labels = get_random_images(5)
    fig = plt.figure(figsize=(10, 10))
    for ii in range(len(images)):
        image = to_pil(images[ii])
        index = predict_image(model, image)
        label = labels_map[index]
        sub = fig.add_subplot(1, len(images), ii+1)
        sub.set_title("Actual: {0} ({1})\nGuess: {2} ({3})".format(
            labels_map[int(labels[ii])], int(labels[ii]), label, index))
        plt.axis('off')
        plt.imshow(image)
    plt.show()

If you set it up correctly you might see something like:

[Figure: Inference with Python]

Exporting the model to ONNX format

Great: you have a model, and you know how to use it for inference in Python. Now we want to export it to ONNX format so we can use it in other contexts, specifically the browser.

The export itself is done by PyTorch’s torch.onnx.export; we then use onnx to validate the exported file and onnxruntime to run it once as a sanity check. The code is short, so it is all included below.

The notable piece here is identifying the input tensor shape: [batch_size, 3, 224, 224], i.e. a batch of 3-channel 224x224 images in channels-first order. We also mark the batch_size dimension as dynamic, so the exported model accepts any batch size. The web example passes in a single image at a time, but the same resnet18.onnx file can also score batches.


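    import numpy as np
    import onnx
    import onnxruntime as ort
    import torch

    batch_size = 1
    image_size = 224
    model_name = "resnet18.onnx"

    # Load on the CPU so the model and the dummy input live on the same device
    model = torch.load('./best.pt', map_location='cpu')
    model.eval()

    # Dummy input in NCHW layout [batch, channels, height, width];
    # the exporter traces the model by running it on this tensor
    x = torch.randn(batch_size, 3, image_size, image_size, requires_grad=False)
    with torch.no_grad():
        torch_out = model(x)

    # Export the model, marking dimension 0 (the batch) as dynamic
    torch.onnx.export(model,
                      x,
                      model_name,
                      export_params=True,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={
                          'input': {0: 'batch_size'},
                          'output': {0: 'batch_size'}
                      })

    # Check that the exported graph is well-formed
    onnx_model = onnx.load(model_name)
    onnx.checker.check_model(onnx_model)

    # Run the ONNX model and confirm it matches PyTorch's output
    ort_sess = ort.InferenceSession(model_name, providers=['CPUExecutionProvider'])
    outputs = ort_sess.run(None, {'input': x.numpy()})
    np.testing.assert_allclose(torch_out.numpy(), outputs[0], rtol=1e-03, atol=1e-05)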

Running this code produces the file resnet18.onnx, which we will use in the following sections.
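Because dimension 0 was exported as the symbolic axis batch_size, the model accepts inputs of shape [N, 3, 224, 224] for any N of at least 1, while the other three dimensions are fixed; ort_sess.get_inputs()[0].shape should report this as ['batch_size', 3, 224, 224]. The hypothetical shape_matches helper below spells out that rule, and math.prod gives the number of float32 values an input must contain (150528 for a single image, the length of the typed array the JavaScript code builds). Note that a channels-last [1, 224, 224, 3] array has the same element count but the wrong layout.

```python
import math

# Input shape of the exported model. The string entry is the symbolic axis
# created by dynamic_axes={'input': {0: 'batch_size'}}; the rest are fixed.
EXPECTED_INPUT_SHAPE = ("batch_size", 3, 224, 224)


def shape_matches(shape, expected=EXPECTED_INPUT_SHAPE):
    """True if `shape` fits `expected`; symbolic (string) axes accept any size >= 1."""
    if len(shape) != len(expected):
        return False
    return all(
        dim >= 1 if isinstance(want, str) else dim == want
        for dim, want in zip(shape, expected)
    )


if __name__ == "__main__":
    for shape in [(1, 3, 224, 224), (8, 3, 224, 224), (1, 224, 224, 3)]:
        # math.prod(shape) is the number of float32 values the tensor holds
        print(shape, shape_matches(shape), math.prod(shape))
```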

Inference in NodeJS

Before we get to the browser, I think it’s easier to start with Node.js: there is no HTML/CSS to deal with, so we can focus on the JavaScript code.

Here are the package.json dependencies we are working with for this section:

"dependencies": {
    "blueimp-load-image": "^5.16.0",
    "canvas": "^2.11.2",
    "ndarray": "^1.0.19",
    "ndarray-ops": "^1.2.2",
    "onnxruntime-node": "^1.14.0",
    "save": "^2.9.0"
  }

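We will use blueimp-load-image to load the image, canvas to resize it, ndarray and ndarray-ops to rearrange the pixel data, and onnxruntime-node to run the model.

First, we will load the image and resize it to 224x224. We then rearrange its pixels into the float32 [1, 3, 224, 224] layout the model expects; ndarray and ndarray-ops do that rearranging, and the result is handed to onnxruntime-node as an ort.Tensor.

Start with the imports: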

    const ort = require("onnxruntime-node");
    // Local helper from the accompanying code (not shown here) that loads and resizes an image
    const ImageLoader = require('./ImageLoader.js');
    const ndarray = require("ndarray");
    const ops = require("ndarray-ops");
    const Tensor = ort.Tensor;

Then the code to load, resize, and preprocess the image:


    function preprocess(data, width, height) {
        // ImageData pixels are RGBA in row-major order: [height, width, 4], values 0-255.
        const dataFromImage = ndarray(new Float32Array(data), [height, width, 4]);
        const dataProcessed = ndarray(new Float32Array(width * height * 3), [1, 3, height, width]);

        // Match the normalization in the Python test_transforms: scale 0-255 to [0, 1],
        // then normalize each RGB channel with the ImageNet mean/std.
        const mean = [0.485, 0.456, 0.406];
        const std = [0.229, 0.224, 0.225];
        ops.divseq(dataFromImage, 255.0);

        // Realign imageData from [224, 224, 4] (RGBA) to [1, 3, 224, 224] (RGB), dropping alpha.
        for (let c = 0; c < 3; c++) {
            const channel = dataProcessed.pick(0, c, null, null);
            ops.assign(channel, dataFromImage.pick(null, null, c));
            ops.subseq(channel, mean[c]);
            ops.divseq(channel, std[c]);
        }

        return dataProcessed.data;
    }

    async function getInputs(imageSize = 224){
        // Load image.
        const imageLoader = new ImageLoader(imageSize, imageSize);
        const imageData = await imageLoader.getImageData('./data/inference/human/rider-12.jpg');
        // Preprocess the image data to match input dimension requirement
        const width = imageSize;
        const height = imageSize;
        const preprocessedData = preprocess(imageData.data, width, height);
        
        const tensorB = new Tensor('float32', 
            preprocessedData, 
            [1, 3, imageSize, imageSize]
        );
        return tensorB;
    }
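The realignment step is easier to follow with the index arithmetic written out. Canvas pixel data is a flat array in [height, width, 4] (RGBA) order, while the model wants a flat [1, 3, height, width] array of RGB planes, so channel c of pixel (x, y) moves from position (y * width + x) * 4 + c to position c * height * width + y * width + x. Below is a plain-Python sketch of that mapping (rgba_hwc_to_nchw is a hypothetical name); it normalizes with the ImageNet mean/std from test_transforms, assuming that is what the model was trained with.

```python
# Assumption: the ImageNet statistics used by the Python test_transforms.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def rgba_hwc_to_nchw(pixels, width, height):
    """Convert flat RGBA bytes laid out [height][width][4] (canvas ImageData
    order) into a flat, normalized float list laid out [1][3][height][width]."""
    plane = width * height
    out = [0.0] * (3 * plane)
    for y in range(height):
        for x in range(width):
            src = (y * width + x) * 4      # first byte of this pixel's RGBA
            for c in range(3):             # keep R, G, B; drop alpha
                value = pixels[src + c] / 255.0
                out[c * plane + y * width + x] = (value - MEAN[c]) / STD[c]
    return out


if __name__ == "__main__":
    # A hypothetical 2x1 image: one red pixel, then one blue pixel (alpha 255)
    image = [255, 0, 0, 255,   0, 0, 255, 255]
    tensor = rgba_hwc_to_nchw(image, width=2, height=1)
    # Layout: [R plane (2 values), G plane (2 values), B plane (2 values)]
    print([round(v, 3) for v in tensor])
```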

Finally, we can run the model and get the output:


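    // Class names in the same order as the model's outputs
    const labelsMap = ["bike", "car", "cat", "dog", "flower", "horse", "human"];

    // Index of the largest value, i.e. the predicted class
    function indexOfMax(arr) {
        let maxIndex = 0;
        for (let i = 1; i < arr.length; i++) {
            if (arr[i] > arr[maxIndex]) maxIndex = i;
        }
        return maxIndex;
    }

    async function inference(){

        // load the ONNX model file
        const cpuSession = await ort.InferenceSession.create('./resnet18.onnx');

        // generate model input
        const inferenceInputs = await getInputs();

        // execute the model
        const output = await cpuSession.run({"input": inferenceInputs});

        // prints the label of the class with the highest score
        console.log(labelsMap[indexOfMax(output.output.data)]);
    }

    inference().catch(console.error);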
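output.output.data is a flat Float32Array in row-major order, and output.output.dims gives its shape ([1, 7] for one image). With a batch of N images the buffer holds N * 7 scores, with image i’s scores at positions i * 7 through i * 7 + 6. A Python sketch of splitting such a buffer into per-image predictions; predictions_from_flat and the scores are hypothetical:

```python
LABELS = ["bike", "car", "cat", "dog", "flower", "horse", "human"]


def predictions_from_flat(flat_scores, num_classes=len(LABELS)):
    """Split a flat [batch * num_classes] score buffer into rows
    (row i = flat[i*C:(i+1)*C]) and return the best label for each row."""
    if len(flat_scores) % num_classes != 0:
        raise ValueError("buffer length is not a multiple of num_classes")
    labels = []
    for start in range(0, len(flat_scores), num_classes):
        row = flat_scores[start:start + num_classes]
        best = max(range(num_classes), key=row.__getitem__)
        labels.append(LABELS[best])
    return labels


if __name__ == "__main__":
    # Hypothetical output for a batch of 2 images (2 x 7 = 14 scores)
    flat = [0.1, 2.5, 0.0, -1.0, 0.3, 0.2, 0.0,
            -0.4, 0.1, 0.2, 0.3, 3.9, 0.0, 1.1]
    print(predictions_from_flat(flat))  # prints ['car', 'flower']
```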

Inference in the Web Browser

Here is the demo link for the web browser inference.

Now I’ll walk you through a minimal example with fewer bells and whistles.

First, we need to use browserify to bundle the code, because ndarray and ndarray-ops are Node (CommonJS) modules and won’t load in the browser as-is.


    $ npm install --global browserify
    $ npm install ndarray ndarray-ops

Next create a file main.js:

var ndarray = require("ndarray");
var ops = require("ndarray-ops");

global.window.ops = ops;
global.window.ndarray = ndarray;

Then we can bundle the code:


    $ browserify main.js -o bundle.js

Finally, in our barebones index.html we can load the bundle and run the inference:

Our <head> tag imports.

<head>
    <title>ONNX Runtime JavaScript Browser Example</title>
    <!-- import ONNXRuntime Web from CDN -->
    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
    <script src="./bundle.js"></script>
</head>

In our <body> we have the following <script> code:

Notice we use the same preprocess function from our nodejs example. Most of the code here is just loading the image in the format we want (from a hard coded url in this case). Then we run the model and get the output.

I won’t go over the image loading code as it’s outside the scope of this blog post, but I’m sure if you’ve followed along this far it won’t be too hard to figure out.

 <script>
        // Load an image, draw it onto a width x height canvas, and return its
        // raw RGBA pixels (a Uint8ClampedArray of width * height * 4 values, 0-255).
        function getImageData(url, width, height) {
            const img = new Image();

            img.setAttribute('crossOrigin', 'anonymous');

            return new Promise((resolve, reject) => {

                img.onload = function () {
                    const canvas = document.createElement("canvas");
                    canvas.width = width;
                    canvas.height = height;

                    const ctx = canvas.getContext("2d");
                    // Stretch the image to the model's input size (the Python code
                    // resizes and center-crops instead, so results may differ slightly).
                    ctx.drawImage(this, 0, 0, width, height);
                    resolve(ctx.getImageData(0, 0, width, height).data);
                };
                img.onerror = reject;

                img.src = url;
            });
        }

        const labelsMap = [
            "bike",
            "car",
            "cat",
            "dog",
            "flower",
            "horse",
            "human",
        ]

        function indexOfMax(arr) {
            if (arr.length === 0) {
                return -1;
            }

            // `let`, not `const`: both are reassigned in the loop
            let max = arr[0];
            let maxIndex = 0;

            for (let i = 1; i < arr.length; i++) {
                if (arr[i] > max) {
                    maxIndex = i;
                    max = arr[i];
                }
            }

            return maxIndex;
        }

        function preprocess(data, width, height) {
            // Canvas pixels are RGBA in row-major order: [height, width, 4], values 0-255.
            const dataFromImage = ndarray(new Float32Array(data), [height, width, 4]);
            const dataProcessed = ndarray(new Float32Array(width * height * 3), [1, 3, height, width]);

            // Match the normalization in the Python test_transforms: scale 0-255 to [0, 1],
            // then normalize each RGB channel with the ImageNet mean/std.
            const mean = [0.485, 0.456, 0.406];
            const std = [0.229, 0.224, 0.225];
            ops.divseq(dataFromImage, 255.0);

            // Realign imageData from [224, 224, 4] (RGBA) to [1, 3, 224, 224] (RGB), dropping alpha.
            for (let c = 0; c < 3; c++) {
                const channel = dataProcessed.pick(0, c, null, null);
                ops.assign(channel, dataFromImage.pick(null, null, c));
                ops.subseq(channel, mean[c]);
                ops.divseq(channel, std[c]);
            }

            return dataProcessed.data;
        };

        // use an async context to call onnxruntime functions.
        async function main() {
            try {
                // raw RGBA pixels, already resized to 224x224
                const data = await getImageData('./data/data/flowers/0001.png', 224, 224);
                // create a new session and load the specific model.
                const session = await ort.InferenceSession.create('./resnet18.onnx');
                // prepare feeds. use model input names as keys.
                const feeds = { input: new ort.Tensor(new Float32Array(preprocess(data, 224, 224)), [1, 3, 224, 224]) };
                // feed inputs and run
                const results = await session.run(feeds);
                // read from results
                const dataC = results.output.data;
                document.write(`Label: ${labelsMap[indexOfMax(dataC)]}`);
            } catch (e) {
                document.write(`failed to inference ONNX model: ${e}.`);
            }
        }

        main();
    </script>

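This isn’t meant to be a pretty or compelling example, just the bare minimum to get something working in the browser. To test it locally, serve the folder over HTTP (browsers generally refuse to fetch the model and images from file:// URLs). I did: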

    
        $ python3 -m http.server

Then I went to http://localhost:8000/. You should see the label of the image printed on the screen after a few seconds (no loading spinner, very minimal example 😂).

Thanks everyone. Feel free to contact me through the contact form on my website if you have any questions or comments.

My next blog post will deal with deploying ML models to a different context, a custom FPGA-based hardware accelerator, so stay tuned.