NumPy-like ML Libraries: JAX (by Google) & Theano (Legacy Research Library)

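Modern machine learning relies on fast numerical computation, automatic differentiation, and efficient use of GPUs/TPUs. While NumPy remains the foundation for numerical Python programming, it lacks several capabilities that training today's deep learning models depends on. It cannot compute gradients automatically, it runs only on the CPU, and it executes operations one at a time instead of compiling whole functions into optimized code.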

This gap led to the creation of specialized ML libraries that feel like NumPy but go far beyond it.

Two significant libraries in this category are:

  • JAX — Google’s high-performance numerical computing & ML library

  • Theano — one of the earliest deep learning libraries (now deprecated but historically important)

Understanding these helps students learn how modern ML frameworks evolved, how computation graphs work, and how GPU acceleration changed the ML landscape.

Let’s explore both in detail.


 1. JAX: Google’s High-Performance ML Library

JAX is a high-performance numerical computing library developed by Google. It offers a syntax almost identical to NumPy but adds powerful capabilities such as:

  • Auto-differentiation

  • GPU/TPU support

  • Just-In-Time (JIT) compilation

  • Parallel execution

  • Vectorization

JAX is especially popular in research labs, advanced ML teams, and scientific computing environments.


 1.1 Key Features of JAX

 1. NumPy-Like Syntax

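The jax.numpy module mirrors most of the NumPy API, so NumPy code often carries over with little more than a changed import. The main difference to remember is that JAX arrays are immutable.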

import jax.numpy as jnp
jnp.array([1, 2, 3])
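Here is a slightly fuller comparison, written as a minimal sketch. It computes the same expression with NumPy and with jax.numpy, then shows the difference that most often surprises NumPy users: JAX arrays cannot be modified in place and are "updated" through the .at[...] syntax, which returns a new array. The numbers used are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

# The same expression written with NumPy and with jax.numpy
x_np = np.array([1.0, 2.0, 3.0])
x_jx = jnp.array([1.0, 2.0, 3.0])

y_np = np.sqrt(x_np ** 2 + 1.0).sum()
y_jx = jnp.sqrt(x_jx ** 2 + 1.0).sum()
print(y_np, y_jx)  # equal up to float32 rounding (JAX defaults to float32)

# NumPy arrays can be modified in place...
x_np[0] = 10.0

# ...but JAX arrays are immutable: .at[...].set(...) returns a new array
x_new = x_jx.at[0].set(10.0)
print(x_jx)   # the original array is unchanged
print(x_new)  # the new array has 10.0 in position 0
```

Because arrays never change underneath a function, JAX can safely trace, transform, and compile functions that use them.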

 2. Automatic Differentiation (grad)

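jax.grad takes a Python function and returns a new function that computes its derivative. For f(x) = x**2 + 3 the derivative is 2x, so grad(f)(2.0) returns 4.0. In model training, the same mechanism computes the gradient of a loss with respect to the model's parameters.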

from jax import grad

def f(x):
    return x**2 + 3

grad(f)(2.0)  # derivative 2x evaluated at x = 2, i.e. 4.0
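To connect grad to actual training, here is a minimal sketch of gradient descent on a linear model. The toy data (y = 3x + 1), the learning rate, and the step count are hypothetical choices for illustration. The point is that jax.grad differentiates the loss with respect to its first argument, here a tuple of parameters, and returns gradients with the same structure.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data generated from y = 3x + 1 (no noise)
xs = jnp.linspace(-1.0, 1.0, 20)
ys = 3.0 * xs + 1.0

def loss(params, x, y):
    """Mean squared error of the linear model y_hat = w * x + b."""
    w, b = params
    y_hat = w * x + b
    return jnp.mean((y_hat - y) ** 2)

# grad differentiates with respect to the first argument (params)
grad_loss = jax.grad(loss)

params = (0.0, 0.0)
learning_rate = 0.1
for step in range(200):
    g_w, g_b = grad_loss(params, xs, ys)  # gradients mirror the (w, b) tuple
    params = (params[0] - learning_rate * g_w,
              params[1] - learning_rate * g_b)

print(params)  # close to (3.0, 1.0)
```

Real models have many more parameters, usually stored in nested dicts or lists, but jax.grad handles those structures the same way.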

 3. JIT Compilation (jit)

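jax.jit traces a Python function the first time it is called, compiles the traced program with XLA (Accelerated Linear Algebra, Google's optimizing compiler), and reuses the compiled code on later calls with the same input shapes and dtypes.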

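import jax.numpy as jnp
from jax import jit

@jit
def compute(x):
    return x * x  # traced and compiled by XLA on the first call

compute(jnp.arange(4.0))  # elementwise square of [0., 1., 2., 3.]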
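Compiling on the first call and reusing the result afterwards can be observed directly. In this minimal sketch, a Python side effect (appending to a list) runs only while JAX traces the function. Calls with an already-seen shape reuse the cached compiled code, and a new shape triggers a fresh trace and compile. The trace_log list exists only for this demonstration.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time JAX traces (and therefore compiles) the function

@jax.jit
def compute(x):
    # This Python side effect runs during tracing, not on cached calls
    trace_log.append(x.shape)
    return x * x

compute(jnp.ones(3))  # first call: trace + XLA compile for shape (3,)
compute(jnp.ones(3))  # same shape and dtype: reuses the compiled code
compute(jnp.ones(5))  # new shape: traced and compiled again

print(trace_log)  # [(3,), (5,)]
```

This is also why Python side effects such as print inside a jitted function do not run on every call.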

 4. Vectorization (vmap)

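jax.vmap turns a function written for a single example into one that processes a whole batch, without explicit Python loops. JAX pushes the batch dimension down into the underlying array operations.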
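A minimal sketch of vmap: the dot function below is written for single vectors, and vmap lifts it to a batch of vectors. The in_axes=(0, None) argument says "map over axis 0 of the first argument, and pass the second argument unchanged to every call". The matrix and vector values are arbitrary.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for single vectors of shape (n,)
    return jnp.dot(a, b)

A = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 vectors
v = jnp.array([1.0, 0.0, -1.0])     # one vector shared by every batch element

# Explicit Python loop over the batch
looped = jnp.stack([dot(row, v) for row in A])

# vmap: map over axis 0 of A, broadcast v (None) to every call
batched_dot = jax.vmap(dot, in_axes=(0, None))
vectorized = batched_dot(A, v)

print(vectorized)  # [-2. -2. -2. -2.]
```

Both versions give the same numbers, but the vmapped one runs as a single batched array operation and can be combined with jit and grad.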


 5. Parallelism (pmap)

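jax.pmap runs the same function in parallel on multiple devices (GPUs or TPU cores). It splits the input along its leading axis and gives each device one slice, so the leading axis size cannot exceed the number of available devices.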
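As a minimal sketch, the example below sizes its input to jax.local_device_count(), so it also runs on an ordinary CPU-only machine, where that count is 1. Each device sums its own slice, and jax.lax.psum then adds the partial sums across all devices named by axis_name. The axis name "devices" is an arbitrary label chosen for this example.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # 1 on a plain CPU machine

# One row per device: pmap splits the leading axis across devices
x = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

def per_device_total(chunk):
    local = jnp.sum(chunk)                            # computed on each device
    return jax.lax.psum(local, axis_name="devices")   # all-reduce across devices

totals = jax.pmap(per_device_total, axis_name="devices")(x)
print(totals)  # every device ends up holding the global sum
```

On a multi-GPU or TPU machine, the same code spreads the rows across the accelerators without modification.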


1.2 How JAX Works Internally

JAX traces a Python function into an intermediate representation called a jaxpr, lowers that program to XLA, compiles it, and executes the result on CPU, GPU, or TPU.
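Each stage of this pipeline can be inspected. In the sketch below, jax.make_jaxpr prints the traced intermediate representation, and jax.jit(f).lower(...) produces the program handed to XLA. Depending on the JAX version, that program is printed as StableHLO or HLO text. The function f is an arbitrary example.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0 + x

# Stage 1: tracing turns the Python function into a jaxpr (JAX's IR)
jaxpr = jax.make_jaxpr(f)(3.0)
print(jaxpr)  # lists primitives such as sin, mul and add

# Stage 2: jit lowers the traced program to XLA's input language
lowered = jax.jit(f).lower(3.0)
print(lowered.as_text()[:300])  # beginning of the module handed to XLA

# Stage 3: the compiled executable runs on the default backend
print(jax.jit(f)(3.0), jax.default_backend())
```

Reading jaxprs is a practical way to see exactly which operations JAX recorded from your Python code.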

This unlocks high-speed computation, making JAX ideal for:

  • Deep learning research

  • Reinforcement learning

  • Simulation & scientific computing

  • Differentiable programming


1.3 Where JAX Is Used

  • Google Research

  • DeepMind

  • Academic ML papers

  • Scientific simulations

  • High-performance ML systems


2. Theano: The Legacy Deep Learning Library

Before TensorFlow and PyTorch dominated the ML world, Theano pioneered deep learning computation in Python.
Developed at the University of Montreal (MILA), it introduced many concepts that later ML frameworks adopted.

Although development of the original Theano ended in 2017, it remains historically important and still appears in some older academic code bases and course material.


2.1 Key Features of Theano

 1. Symbolic Computation

You define expressions symbolically, and Theano compiles them into optimized code.

 2. Automatic Differentiation

Theano was one of the first Python libraries to compute gradients automatically from an expression graph.

 3. GPU Acceleration

Theano enabled researchers to use GPUs for deep learning before it was mainstream.

 4. Graph-Based Execution

Theano used computational graphs, which influenced TensorFlow 1.x and other frameworks.
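To build intuition for symbolic computation, symbolic differentiation, and graph-based execution, here is a toy sketch in plain Python. It is not Theano's API and not how Theano is implemented. The names Node, scalar, grad, evaluate, and function are hypothetical, chosen to echo Theano's workflow: declare a symbol, build an expression graph, derive a new graph for the gradient, then "compile" it into a callable. Only +, * and ** with a constant exponent are supported.

```python
# Toy symbolic-graph engine loosely echoing Theano's workflow.
# All classes and functions here are hypothetical, for illustration only.

class Node:
    def __init__(self, op, inputs=(), value=None, name=None):
        self.op, self.inputs, self.value, self.name = op, inputs, value, name

    # Operator overloading builds graph nodes instead of computing numbers
    def __add__(self, other):
        return Node("add", (self, as_node(other)))

    def __mul__(self, other):
        return Node("mul", (self, as_node(other)))

    def __pow__(self, k):
        return Node("pow", (self, as_node(k)))

def as_node(v):
    return v if isinstance(v, Node) else Node("const", value=float(v))

def scalar(name):
    return Node("var", name=name)

def grad(expr, wrt):
    """Return a NEW graph for d(expr)/d(wrt), built with the chain rule."""
    if expr is wrt:
        return as_node(1.0)
    if expr.op in ("var", "const"):
        return as_node(0.0)
    a, b = expr.inputs
    if expr.op == "add":
        return grad(a, wrt) + grad(b, wrt)
    if expr.op == "mul":
        return grad(a, wrt) * b + a * grad(b, wrt)
    if expr.op == "pow":  # toy restriction: the exponent b must be a constant
        return b * a ** (b.value - 1) * grad(a, wrt)
    raise ValueError(f"unknown op {expr.op}")

def evaluate(node, env):
    """Walk the graph and compute a number, looking up variables in env."""
    if node.op == "var":
        return env[node.name]
    if node.op == "const":
        return node.value
    a, b = (evaluate(i, env) for i in node.inputs)
    if node.op == "add":
        return a + b
    if node.op == "mul":
        return a * b
    return a ** b  # "pow"

def function(inputs, output):
    """'Compile' a graph into a plain Python callable."""
    def run(*args):
        return evaluate(output, {v.name: arg for v, arg in zip(inputs, args)})
    return run

x = scalar("x")
f = x ** 2 + 3       # builds a graph; nothing is computed yet
f_grad = grad(f, x)  # a second, symbolic graph representing 2 * x
gradient_fn = function([x], f_grad)
print(gradient_fn(2.0))  # 4.0
```

A real graph compiler does far more at the "compile" step, such as simplifying the graph and generating native or GPU code, but the separation between building, differentiating, and running a graph is the same idea.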


2.2 Why Theano Was Important

  • It shaped early neural network research

  • Provided the foundation for frameworks like Keras, Lasagne, and Blocks

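  • Introduced graph-level optimizations, such as merging duplicate subexpressions and rewriting expressions into numerically stable forms, that later frameworks also rely on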

Between roughly 2010 and 2017, it was one of the most widely used tools for deep learning research.


2.3 Why Theano Was Deprecated

  • Strong competition from Google (TensorFlow) and Facebook (PyTorch)

  • Maintenance cost was too high for a research lab

  • Decline in community contributions

While not used widely today, understanding Theano gives students insight into how ML libraries evolved.


3. JAX vs Theano — Comparison

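| Feature         | JAX                              | Theano                            |
| --------------- | -------------------------------- | --------------------------------- |
| Status          | Active                           | Deprecated                        |
| Syntax          | NumPy-like                       | Symbolic computation              |
| Device support  | CPU, GPU, TPU                    | CPU, GPU                          |
| Differentiation | Automatic differentiation (grad) | Symbolic differentiation (T.grad) |
| Execution       | JIT compilation (XLA)            | Graph compiler                    |
| Best for        | Research, advanced ML            | Legacy academic research          |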

4. Code Comparison (Simple Example)

 Using JAX

from jax import grad

def f(x):
    return x**2 + 3  # an ordinary Python function

print(grad(f)(2.0))  # 4.0, since df/dx = 2x

 Using Theano

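import theano
import theano.tensor as T

x = T.dscalar('x')     # symbolic float64 scalar; holds no value yet
f = x**2 + 3           # symbolic expression graph
f_grad = T.grad(f, x)  # new graph representing df/dx = 2*x

gradient_fn = theano.function([x], f_grad)  # compile the graph into a callable
print(gradient_fn(2.0))  # 4.0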

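JAX is more concise: grad differentiates an ordinary Python function directly.
Theano makes every stage explicit (declare symbolic variables, build the expression graph, derive the gradient graph, compile it with theano.function), which gives a deeper look into symbolic graph programming.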
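For a closer side-by-side, here is a minimal sketch of roughly the JAX counterpart of compiling a gradient with theano.function: jax.grad derives the gradient function and jax.jit compiles it with XLA. The second part shows that JAX transformations compose. Taking grad twice gives the second derivative (2 for x**2 + 3), and vmap evaluates it over a batch of arbitrary inputs.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2 + 3

# grad derives the derivative; jit compiles it with XLA
gradient_fn = jax.jit(jax.grad(f))
print(gradient_fn(2.0))  # 4.0

# Transformations compose: second derivative, evaluated over a batch
d2f = jax.vmap(jax.grad(jax.grad(f)))
print(d2f(jnp.array([1.0, 2.0, 3.0])))  # [2. 2. 2.]
```

This composability is a large part of why JAX code stays short even though it performs the same symbolic-to-compiled steps that Theano required you to spell out.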


5. Why Students Should Learn About These Libraries

  •  Learn how auto-differentiation works
  •  Understand the history of ML frameworks
  •  Learn modern high-speed computation with JAX
  •  Build intuition about computation graphs (from Theano)
  •  Improve foundations before moving to PyTorch/TensorFlow

Together, they build both theoretical and practical ML understanding.

JAX and Theano represent two generations of ML frameworks:

  • Theano pioneered deep learning computation

  • JAX represents the current generation of composable, high-performance ML tools

Studying both gives students a unique advantage:
the ability to understand how modern AI libraries evolved, how they work, and how to apply them effectively.

Happy Learning!
