April 22, 2024

Accelerate your Machine Learning Computations with the Bitnami-Packaged JAX framework

A machine learning framework by definition, JAX can also be considered an extensible system for transforming numerical functions. Learn how to get the most out of Bitnami-packaged JAX in this blog post.

This blog was authored by Fran De Paz Galan

With the rapid growth of machine learning (ML) applications, software developers are finding more and more value in the open source software world for tasks such as model training, selection, and evaluation. To cater to our customers and users, we are steadily adding ML-centric open source software tools to the open source Bitnami Application Catalog, as well as its enterprise version—VMware Tanzu Application Catalog.

One such ML-focused tool in our catalog is JAX. A machine learning framework by definition, JAX can also be considered an extensible system for transforming numerical functions. It brings together a modified version of Autograd (which automatically derives the gradient of a function through differentiation) and TensorFlow's XLA (Accelerated Linear Algebra). It is also designed to mirror the structure and workflow of NumPy as closely as possible, and it can work with existing frameworks such as TensorFlow and PyTorch.
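As a rough illustration of what "transforming numerical functions" means in practice, the sketch below applies two JAX transformations to a function written with the NumPy-like API: jax.grad for automatic differentiation and jax.jit for XLA compilation. The function name sum_of_squares and the input values are made up for this example.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Plain numerical code written against the NumPy-like jax.numpy API.
    return jnp.sum(x ** 2)


# jax.grad returns a new function that computes the gradient with respect to x.
grad_fn = jax.grad(sum_of_squares)

# jax.jit returns a version of that function compiled with XLA.
fast_grad_fn = jax.jit(grad_fn)

x = jnp.array([1.0, 2.0, 3.0])
gradient = fast_grad_fn(x)
print(gradient)  # analytically, the gradient of sum(x**2) is 2 * x
```

Transformations compose, so the same function can be differentiated, compiled, and vectorized without being rewritten.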

In this blog post, you will learn how to get the Bitnami-packaged JAX from Bitnami Application Catalog and Tanzu Application Catalog, deploy the container, and set up a development environment. We will also walk through a couple of practical examples using basic Python scripts.

Assumptions and Prerequisites

Deploying the community edition of Bitnami-packaged JAX container from Bitnami Application Catalog

JAX is a development-focused application, so the first step is to set up the development environment where your Python scripts will live. The directory is hosted locally and mounted directly into the container as a volume, which lets you create and modify your scripts from outside the container using your preferred IDE.

$ mkdir ~/jax_scripts && cd ~/jax_scripts

 

Let’s assume we already have a script we want to run inside our container. The easiest way to launch the container will be by executing the following command:

$ docker run --rm --name jax -v ./:/app bitnami/jax:latest my_jax_script.py

This will start a new bitnami/jax container, run my_jax_script.py using Python, and then remove the container.
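The post does not show the contents of my_jax_script.py. As a hypothetical stand-in, a small smoke-test script like the following could be saved in ~/jax_scripts to check that JAX imports correctly and can run a computation inside the container:

```python
# my_jax_script.py (hypothetical contents; not taken from the original post)
import jax
import jax.numpy as jnp

# Report the installed JAX version and the devices JAX can see (CPU, GPU, or TPU).
print("JAX version:", jax.__version__)
print("Devices:", jax.devices())

# Run a small computation to confirm that the backend works end to end.
x = jnp.arange(10.0)
total = jnp.sum(x)
print("Sum of 0..9:", total)
```

If the script prints a version, at least one device, and the sum, the image and the volume mount are working as expected.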

Deploying the enterprise version of Bitnami-packaged JAX container from Tanzu Application Catalog

The following instructions describe how to navigate to the Tanzu Application Catalog and deploy the JAX container image. This blog post uses Photon OS as the recommended base OS image in Tanzu Application Catalog but there are many others to choose from.

  1. Navigate to app-catalog.vmware.com and sign in to your catalog with your VMware account.
  2. In the My Applications section, search for the JAX container and click Details. On the next screen, you will find the instructions for deploying the container. Make sure that your Docker engine is up and running.

  3. Run the commands you will find in the “Consume your Container image” section.

Now that we understand how to use the bitnami/jax image, let’s see a couple of examples to start exploring JAX’s capabilities.

Basic example: Running a Python script within the JAX container

In this first example, we will run a Python script to demonstrate how easy it is to run JAX in the bitnami/jax container. The following code will make use of jax.vmap, one of JAX’s transformations, and jax.numpy, its NumPy mirror-like API:

 

$ cat ~/jax_scripts/my_jax_script_1.py
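import jax.numpy as jnp
from jax import random
from jax import vmap

# The same key is reused for both matrices here; independent draws would
# normally use separate keys obtained with random.split.
key = random.key(1993)
matrix_1 = random.normal(key, (19, 100))
matrix_2 = random.normal(key, (3, 100))

def apply_matrix(m):
    # Multiply the fixed (19, 100) matrix by a single vector of length 100.
    return jnp.dot(matrix_1, m)

print('Vectorised function using vmap')
# vmap maps apply_matrix over the rows of matrix_2, giving a (3, 19) result.
res = vmap(apply_matrix)(matrix_2).block_until_ready()
print(res)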

 

Because the container mounts the host directory as a volume, you only need to create your script in the host folder. After that, start a bitnami/jax container in the background, open a shell in it, and run the script as shown below:

$ docker run -it -d --entrypoint bash --name jax -v ./:/app bitnami/jax:latest
79b650640620b5d39b324806c25c5f99c7c5bb53f92614f0ef09b6f6a9d4a568

$ docker exec -it 79b650 bash

jax@56c55ce909b5:/app$ python my_jax_script_1.py
Vectorised function using vmap
[[ 12.672197     5.573652    -0.610461     9.84283     -0.701779
  14.255789    -5.0340986    4.580534     1.6429727  -23.880184
   1.5640947    4.091061    -3.4076087    7.6722975  -13.565388
   1.9004642    0.67286134  -1.8921701   -5.034804  ]
...
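To make clearer what vmap does in the script above, the following sketch compares the vectorized call with an explicit Python loop over the rows and with a single matrix multiplication. It is self-contained and, unlike the original script, draws the two matrices from separate keys obtained with random.split, so its numbers will differ from the output shown above.

```python
import jax.numpy as jnp
from jax import random, vmap

# Separate keys for the two matrices (an assumption; the original script reuses one key).
key_1, key_2 = random.split(random.key(1993))
matrix_1 = random.normal(key_1, (19, 100))
matrix_2 = random.normal(key_2, (3, 100))


def apply_matrix(m):
    # Acts on one vector of length 100 and returns a vector of length 19.
    return jnp.dot(matrix_1, m)


# Three equivalent ways to apply the function to every row of matrix_2.
batched = vmap(apply_matrix)(matrix_2)
looped = jnp.stack([apply_matrix(row) for row in matrix_2])
matmul = matrix_2 @ matrix_1.T

print(batched.shape)
print(jnp.allclose(batched, looped, atol=1e-3))
print(jnp.allclose(batched, matmul, atol=1e-3))
```

The benefit of vmap is that apply_matrix can be written for a single input while JAX generates the batched computation, instead of looping in Python or rewriting the function by hand.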

Advanced example: Training a neural network with a PyTorch dataloader

JAX deliberately does not provide built-in data loaders or datasets. Instead, it relies on existing data-loading libraries and publicly available datasets, such as TensorFlow's or PyTorch's. The following example, adapted from one supplied in the JAX repository, uses PyTorch's data loading API to train a neural network:

 

## my_jax_script_2.py
# Skipped auxiliary imports, functions and logic
import time
import jax.numpy as jnp
from torch.utils import data
from torchvision.datasets import MNIST

# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

# Get the full train dataset (for checking accuracy while training)
train_images = jnp.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(jnp.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(jnp.array(mnist_dataset_test.test_labels), n_targets)

for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in training_generator:
        y = one_hot(y, n_targets)
        params = update(params, x, y)
    epoch_time = time.time() - start_time
    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
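The excerpt skips the data-loading helpers. The sketch below shows one plausible shape for two of them, modeled on the upstream JAX tutorial: a collate function that stacks samples into NumPy arrays rather than torch tensors, and the FlattenAndCast transform. The bodies are assumptions rather than the exact code behind the excerpt, and random arrays stand in for MNIST images so that the sketch runs without PyTorch.

```python
import numpy as np


def numpy_collate(batch):
    """Stack a list of samples into NumPy arrays instead of torch tensors."""
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        # A sample such as (image, label): collate each field separately.
        return [numpy_collate(samples) for samples in zip(*batch)]
    return np.array(batch)


class FlattenAndCast:
    """Transform that turns an image into a flat float32 vector."""

    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=np.float32))


# Stand-in for four MNIST samples: random 28x28 images with integer labels.
rng = np.random.default_rng(0)
transform = FlattenAndCast()
samples = [(transform(rng.integers(0, 256, size=(28, 28))), label) for label in range(4)]

images, labels = numpy_collate(samples)
print(images.shape, images.dtype, labels)
```

Under this assumption, NumpyLoader would be a thin subclass of torch.utils.data.DataLoader that passes collate_fn=numpy_collate, so every batch reaches JAX as NumPy arrays.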

As we can see, this example combines PyTorch's MNIST dataset with JAX's NumPy API. To use the PyTorch libraries, we just need to install them in our bitnami/jax container before running the script:

$ docker exec -it 79b650 bash
jax@56c55ce909b5:/app$ pip install torch torchvision
jax@56c55ce909b5:/app$ python my_jax_script_2.py
Epoch 0 in 4.01 sec
Training set accuracy 0.9158333539962769
Test set accuracy 0.9196999669075012
...
Epoch 6 in 3.74 sec
Training set accuracy 0.9708166718482971
Test set accuracy 0.9651999473571777
Epoch 7 in 3.77 sec
Training set accuracy 0.9737333655357361
Test set accuracy 0.9669999480247498
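The training helpers that the excerpt skips (one_hot, accuracy, update, and the params they operate on) are not shown in this post. A minimal sketch of what they might look like, modeled on the upstream JAX tutorial, follows. The layer sizes, step size, and the synthetic batch are assumptions chosen only so that the sketch runs on its own, without MNIST or PyTorch.

```python
import jax.numpy as jnp
from jax import grad, jit, random, vmap
from jax.scipy.special import logsumexp


def init_network_params(sizes, key, scale=1e-2):
    """Create a (weights, biases) pair for each dense layer."""
    keys = random.split(key, len(sizes) - 1)
    params = []
    for m, n, k in zip(sizes[:-1], sizes[1:], keys):
        w_key, b_key = random.split(k)
        params.append((scale * random.normal(w_key, (n, m)),
                       scale * random.normal(b_key, (n,))))
    return params


def predict(params, image):
    """Forward pass for a single image; returns log-probabilities per class."""
    activations = image
    for w, b in params[:-1]:
        activations = jnp.maximum(0, jnp.dot(w, activations) + b)  # ReLU
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)


# vmap turns the single-image predict into a batched version.
batched_predict = vmap(predict, in_axes=(None, 0))


def one_hot(x, k, dtype=jnp.float32):
    """One-hot encode integer labels x into k classes."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)


def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)


def loss(params, images, targets):
    # Cross-entropy between log-probabilities and one-hot targets.
    return -jnp.mean(batched_predict(params, images) * targets)


STEP_SIZE = 0.01  # assumed learning rate


@jit
def update(params, x, y):
    """One step of plain gradient descent."""
    grads = grad(loss)(params, x, y)
    return [(w - STEP_SIZE * dw, b - STEP_SIZE * db)
            for (w, b), (dw, db) in zip(params, grads)]


# Synthetic batch standing in for flattened 28x28 images and their labels.
n_targets = 10
params = init_network_params([784, 128, n_targets], random.key(0))
x = random.normal(random.key(1), (32, 784))
y = one_hot(jnp.arange(32) % n_targets, n_targets)

loss_before = loss(params, x, y)
for _ in range(50):
    params = update(params, x, y)
loss_after = loss(params, x, y)
print(loss_before, loss_after)
```

With helpers shaped like these, the loop in my_jax_script_2.py only needs to call update on each batch and accuracy once per epoch, which matches how the excerpt uses them.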

Support and Resources

The Bitnami package for JAX is available from the Bitnami GitHub repository in its community version, and in its enterprise version via the Tanzu Application Catalog. Explore the differences between these catalogs by referring to our blog post.

Should you encounter issues with Bitnami community packages, please don’t hesitate to open an issue in the Bitnami Helm charts or containers GitHub repository. Additionally, if you wish to contribute to the catalog, you can submit a pull request. Our team will review it, offering guidance throughout the merging process for a successful collaboration.

If you are deploying the enterprise version of this package and experiencing any issues, please file a support request via the VMware Cloud Services Console.

For comprehensive guidance on integrating JAX into your code or processes, please consult the official JAX documentation.

If you are interested in learning more about the Tanzu Application Catalog in general, check out our product webpage and additional resources. If you would like to get in touch, contact us.
