Lab 11: Meet the TPU

Prologue: Logistics

Pair Programming

As in Labs 9 and 10, you’ll be working in teams of two for this lab.

Due Dates

For this lab, you’ll be turning in the following deliverables:

Starter Code

You can get the starter code for this lab by cloning the lab repository:

git clone git@github.com:accelerated-computing-class/lab11.git

You can run the Python code using uvx telerun like usual.

Goals for This Lab

Unlike in previous labs, where we programmed a single NVIDIA GPU, in this lab we will be programming a 2x2 mesh of Google TPU v5p devices. This means we will now have to think carefully about networking and inter-device communication. So, what do we have to work with?

Interconnect Topology

In this lab, we will be programming a scale-up domain of multiple devices connected by an inter-chip interconnect (“ICI”) fabric.

In our particular setup, we have 4 TPU v5p devices arranged in a 2x2 grid, with each device connected to one horizontal neighbor and one vertical neighbor. We can think of this topology as a one-dimensional ring of 4 devices, and label our devices clockwise from the top-left with indices 0, 1, 2, 3 to mark their positions in the ring:

In real-world use cases, the TPU v5p platform supports connecting much larger numbers of chips together, reaching scale-up domain sizes as large as thousands of chips. In addition to supporting very large scale-up domains, TPU v5p also supports connecting chips together in a 3D mesh topology, allowing each device to connect to up to 6 neighboring devices, as opposed to the 2 neighbors per device we use in this lab.

Each TPU v5p ICI link supports up to ~92 GB/s of bandwidth in each direction, with a per-hop latency of roughly ~1 μs. Bandwidth is not fungible between links or between directions.
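
To put these numbers in perspective, here’s a quick back-of-envelope calculation (plain Python, using only the figures above) for the time to push a single 64 MiB buffer over one link in one direction:

# Back-of-envelope using the ~92 GB/s per-link, per-direction figure above;
# the ~1 us hop latency is negligible at this size.
bytes_to_send = 64 * 2**20   # 64 MiB
link_bw = 92e9               # bytes/s per link, per direction
print(bytes_to_send / link_bw * 1e6)  # ~729 microseconds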

On-Chip Resources

Inside each TPU v5p, we have two identical cores. Each core is a single-threaded programmable processor, and executes sequences of instructions in a proprietary VLIW-flavored ISA. Each TPU core is roughly analogous to a single core on a CPU, or to a single warp scheduler on a GPU—it has an SRAM scratchpad, a register file, dynamic control logic, and functional units for integer math, floating point math, and other operations.

The TPU makes up for its low degree of multi-core parallelism by implementing a very high degree of SIMD parallelism. For 32-bit floating-point data, the TPU natively operates in terms of 1024-element-wide vectors. Each 1024-wide vector is usually thought of as a two-dimensional array of shape (8, 128), which Google calls a “tile.”

Compared to CPUs or GPUs, one especially notable feature of TPUs is that each TPU core possesses a relatively large, high-bandwidth SRAM scratchpad called “VMEM” (short for “Vector Memory”). On our TPU v5p, VMEM is 64 MiB per core. For comparison, an H100 GPU has only 0.3% as much shared memory per SM (227 KiB vs 64 MiB) and ~23% as much total SRAM scratchpad capacity per chip (30 MiB vs 128 MiB).

Although we will mostly be using VMEM for Part 1 of this lab, it’s worth noting that each of our TPU v5p devices also has 95 GB of attached HBM2e memory shared between the two cores. In contrast to a GPU’s complex memory hierarchy, HBM and VMEM are the only two significant general-purpose memory resources on a TPU; the TPU has no L2 cache, L1 cache, or high-capacity register file like one would find on an NVIDIA GPU.

Further reading: For more information on the hardware architecture of TPUs, the free online book “How to Scale Your Model” is an excellent resource.

Part 1: Implementing Collectives in Pallas

In Part 1 of this lab, we will implement high-performance TPU kernels to carry out two fundamental collective operations: reduce-scatter and all-gather. These two operations can be composed together to implement all-reduce.

Background: Pallas and TPUs

In 2025, there are two main ways to program TPUs:

In this lab, we will be exclusively programming in Pallas. You will not need to interact with JAX at any point (except insofar as Pallas shares some library infrastructure with JAX).

Coming from CUDA, Pallas is familiar to us in many ways:

However, Pallas differs from CUDA in two key ways: it uses a tracing-based compilation model and an array-based programming abstraction.

Tracing-Based Design

The way you define a Pallas program is by “tracing” a Python program. This means that execution happens in two phases:

  1. First, you run your Python program on special “traced”/“symbolic” objects representing values which will exist at runtime in your kernel, but which have no concrete value yet. When your Python code performs operations on these traced values, those operations emit instructions which are recorded by Pallas’s tracing infrastructure, but no code actually executes on the TPU.

  2. Second, once tracing is complete, Pallas compiles a kernel from the emitted instructions. You can then run this kernel on the TPU as many times as you like, with different inputs each time.

Pallas’s tracing-based architecture has a few notable implications:

You can read more about tracing in JAX (which works similarly to Pallas) here.
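
To make the two-phase model concrete, here is a minimal, self-contained sketch (a generic Pallas example, not part of the starter code) in which an ordinary Python for loop runs only at trace time and simply unrolls into four add operations in the compiled kernel:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_four_kernel(x_ref, o_ref):
    acc = x_ref[...]       # traced array value; no concrete data exists at trace time
    for _ in range(4):     # ordinary Python loop: unrolled during tracing
        acc = acc + 1.0
    o_ref[...] = acc

@jax.jit
def add_four(x):
    return pl.pallas_call(
        add_four_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)

print(add_four(jnp.zeros((8, 128), jnp.float32))[0, 0])  # 4.0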

Array-Based Design

Unlike CUDA, Pallas has convenient built-in support for performing bulk operations on multi-dimensional arrays. In fact, multi-dimensional arrays are the primary abstraction Pallas provides for working with numerical data. Arrays in Pallas come in two flavors:

When you load from a Pallas array ref, what you get is an array value. Similarly, when you store to a Pallas array ref, you must provide an array value.

The Pallas compiler automatically exploits SIMD and VLIW parallelism on your behalf as part of its array-based programming model. For example, if x and y are both traced Pallas array values of shape (2048, 1024) with elements of type fp32, then you can compute their elementwise sum simply by writing the expression x + y, and the compiler will automatically decompose that bulk array operation into a sequence of individual vector instructions, packed efficiently into VLIW bundles.

With a few exceptions, the Pallas compiler generally requires that shape information for all arrays must be known at trace time. This imposes some restrictions on what you are able to do with arrays, but it also means that your own Python code can query the concrete .shape of most Pallas array refs or array values at trace time.
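
For example, here is a small hypothetical kernel body (generic names, not from the starter code) that queries concrete shapes with ordinary Python at trace time and then performs a single bulk array operation:

def scaled_sum_kernel(x_ref, y_ref, o_ref):
    # .shape is a concrete Python tuple at trace time, so plain Python logic works
    assert x_ref.shape == y_ref.shape
    rows, cols = x_ref.shape
    scale = 1.0 / float(rows * cols)
    # one bulk operation; the compiler lowers it to vector instructions in VLIW bundles
    o_ref[...] = (x_ref[...] + y_ref[...]) * scale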

Warm-Up: exchange_with_neighbor

With all that preamble out of the way, let’s write our first Pallas program!

Since our goal for this lab is to make use of ICI, we’ll start by writing a simple program that sends data between multiple devices. Our objective in this section will be to implement an “exchange with neighbor” operation, which should copy arrays between devices according to the following pattern:

To do this, you’ll need to fill in the body of the function exchange_with_neighbor_pallas_kernel. A few notes on the execution model of this function:

In our kernel, we’re going to need to transfer data between devices. To do this, we can use Pallas’s remote-DMA (“RDMA”) functionality. The starter code provides the following convenience functions for working with RDMAs:

(We recommend you use these wrapper functions rather than Pallas’s built-in RDMA API, as we believe the wrappers express the underlying semantics slightly more clearly than their built-in equivalents.)

A peculiar feature of the pallas_rdma_start API is that it requires you to pass it a pointer into another device’s address space to identify the destination buffer for your RDMA. But how can we obtain a pointer into a different device’s address space? How are we supposed to know where anything is? The answer is that we can exploit a trick: because the Pallas compiler guarantees that your kernel’s address space is laid out identically on each device, a pointer to a buffer on your current device can also identify the corresponding buffer on a different device!

As one final point, you may notice from reading the starter code that the RDMA API requires the use of DMA semaphores. DMA semaphores are a special type of Pallas array ref used to track progress on DMAs. Conceptually, you can think of each DMA semaphore as an integer counter:
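
For intuition about how starting an RDMA and waiting on its semaphores fit together, here is a rough, illustrative sketch using Pallas’s built-in remote-copy API (pltpu.make_async_remote_copy); this is not the lab’s intended code path (use the starter code’s wrappers in your kernel), and dst_device is a placeholder for however your kernel identifies its neighbor:

from jax.experimental.pallas import tpu as pltpu

def send_to_neighbor(src_ref, dst_ref, send_sem, recv_sem, dst_device):
    # src_ref / dst_ref refer to the same buffer layout on every device
    # (the symmetric-address-space trick described above)
    rdma = pltpu.make_async_remote_copy(
        src_ref=src_ref,
        dst_ref=dst_ref,
        send_sem=send_sem,       # tracks our outgoing copy
        recv_sem=recv_sem,       # tracks data arriving at this device
        device_id=(dst_device,),
        device_id_type=pltpu.DeviceIdType.LOGICAL,
    )
    rdma.start()                 # kick off the transfer asynchronously
    rdma.wait()                  # block until both semaphores report completion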

You can allocate DMA semaphores for your kernel by filling in the exchange_with_neighbor_pallas_scratch_specs function:

You now have everything you need to implement an “exchange with neighbor” kernel!

Deliverable: Implement exchange_with_neighbor_pallas_kernel in collectives.py so that it sends data between devices in the expected pattern (you should see a rel_rmse of 0.00e+0 when you run it).

The code for your “exchange with neighbor” kernel will likely end up looking very simple, and that’s expected! This section is just to make sure we have all the fundamentals down before moving on to implementing real collectives.

Reduce-Scatter and All-Gather

We’re now ready to implement two more interesting collective kernels: reduce-scatter and all-gather. Both of these collectives conceptually take a one-dimensional vector of numbers on each device as input, and produce a one-dimensional vector of numbers of a different size as output. The semantics of each are as follows:

(There are also variants of reduce-scatter which use reduction operators other than “sum.” As long as the reduction operator is commutative and associative, the implementation will look more or less the same.)

In our case, the numbers in our arrays will all be of type float32.

To make our collective implementations play nicely with the TPU hardware, we won’t implement them in terms of one-dimensional Pallas arrays. Instead, we’ll view each input/output array as a sequence of (8, 128) tiles, so that each input/output array has a three-dimensional shape of the form (N, 8, 128). For the most part, you can treat tiles as indivisible units, and only worry about slicing or concatenating these three-dimensional arrays along their first dimension.
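
As a concrete reference for the semantics, here is a plain-NumPy sketch following the standard definitions of these collectives and the (N, 8, 128) tiling convention above (the starter code may pin down details, such as shard ordering, differently):

import numpy as np

def reduce_scatter_reference(inputs):
    # inputs: one (N, 8, 128) array per device; device i keeps shard i of the sum
    total = np.sum(inputs, axis=0)                 # elementwise sum across devices
    return np.split(total, len(inputs), axis=0)    # each shard has shape (N/4, 8, 128)

def all_gather_reference(inputs):
    # inputs: one (N, 8, 128) array per device; every device ends up with all of them
    gathered = np.concatenate(inputs, axis=0)      # shape (4N, 8, 128)
    return [gathered.copy() for _ in inputs]

def all_reduce_reference(inputs):
    # all-reduce = reduce-scatter followed by all-gather
    return all_gather_reference(reduce_scatter_reference(inputs))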

We want to develop our own implementations of reduce-scatter and all-gather in Pallas which achieve high throughput across a range of sizes:

Deliverable: Implement the functions reduce_scatter_pallas_kernel and all_gather_pallas_kernel to perform reduce-scatter and all-gather. For full credit, try to achieve at least the following effective bandwidths at each input array size:

Reduce-scatter target bandwidth:

All-gather target bandwidth:

(Note that “bandwidth” here refers to achieved outgoing bandwidth per device for reduce-scatter, and achieved incoming bandwidth per device for all-gather.)

Useful Features of Pallas

In your implementation, you’ll probably need to use a few more features of Pallas that we haven’t covered yet:

Implementation Tips

Part 2: Overlapping Collectives with Computation

Now that we’ve implemented reduce-scatter and all-gather, let’s see how we can put them to use!

Workload: A Tensor-Parallel Neural Network

In this section, we’ll implement a simple neural network model consisting of a sequence of dense matrix multiplications. Each layer of our network takes as input a batched activation matrix X of shape

X : [N_\text{batch}, d_\text{in-out}]

along with two weight matrices W_1, W_2 of shape

W_1 : [d_\text{in-out}, d_\text{middle}] \qquad W_2 : [d_\text{middle}, d_\text{in-out}]

and performs the following operation:

X \gets f(X \, W_1) \, W_2

where f is an elementwise nonlinearity (for this lab, we’ll use gelu). Our complete model consists of multiple such layers stacked together, each with different weights.

We could shard this model over the 4 TPUs in our ICI domain in several ways, such as data parallelism or pipeline parallelism. In this lab, we’ll implement tensor parallelism.

Tensor parallelism exploits the fact that each layer can be rewritten as a sum of contributions from independent weight subsets. If we decompose W_1 into N_\text{devices} column blocks W^i_1 and W_2 into N_\text{devices} row blocks W^i_2:

W_1 = \begin{bmatrix} W^0_1 & W^1_1 & W^2_1 & W^3_1 \end{bmatrix} \qquad W_2 = \begin{bmatrix} W^0_2 \\ W^1_2 \\ W^2_2 \\ W^3_2 \end{bmatrix}

then each layer’s computation becomes:

\begin{align*} f(X \, W_1) \, W_2 & = f(X \, \begin{bmatrix} W^0_1 & W^1_1 & W^2_1 & W^3_1 \end{bmatrix}) \, \begin{bmatrix} W^0_2 \\ W^1_2 \\ W^2_2 \\ W^3_2 \end{bmatrix} \\ & = f(X W^0_1)W^0_2 + f(X W^1_1)W^1_2 + f(X W^2_1)W^2_2 + f(X W^3_1)W^3_2 \end{align*}

If X is replicated across all devices, each device can independently compute one f(X \, W^i_1) \, W^i_2 term, and then we sum the results at the end of the layer. How do we sum them? An all-reduce!
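
A quick numerical sanity check of this identity, in plain NumPy with small stand-in dimensions (not the lab’s actual sizes) and a tanh-approximate gelu for f:

import numpy as np

def gelu(x):  # tanh approximation, just for this check
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(0)
n_dev, n_batch, d_io, d_mid = 4, 8, 16, 32
X = rng.standard_normal((n_batch, d_io))
W1 = rng.standard_normal((d_io, d_mid))
W2 = rng.standard_normal((d_mid, d_io))

full = gelu(X @ W1) @ W2
W1_cols = np.split(W1, n_dev, axis=1)   # the column blocks W1^i
W2_rows = np.split(W2, n_dev, axis=0)   # the row blocks W2^i
sharded = sum(gelu(X @ W1_cols[i]) @ W2_rows[i] for i in range(n_dev))

print(np.allclose(full, sharded))       # True: per-device contributions sum to the full layer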

We’ll implement this tensor-parallel workload piece by piece, using the following hyperparameters:

N_\text{batch} = 256 \qquad d_\text{in-out} = 1024 \qquad d_\text{middle} = 16384

The starter code refers to these as:

K_1 = d_\text{in-out} = 1024 \qquad K_2 = d_\text{middle} / N_\text{devices} = 4096

Since an all-reduce is a reduce-scatter followed by an all-gather, each layer’s computation from a single device’s perspective has the form:

\text{All-Gather}(\text{Reduce-Scatter}(f(X W^i_1) W^i_2))

Since our model consists of multiple layers running in sequence, for pairs of adjacent layers in the middle of the network we can reassociate the \text{All-Gather} at the end of one layer into the beginning of the next, so that a typical sequence of operations in the middle of the network looks like

\text{Reduce-Scatter}(f(\text{All-Gather}(Y) W^i_1) W^i_2)

where Y is the sharded output of the previous layer’s reduce-scatter.

We can then decompose this computation into two main sub-computations:

Each of these operations involves both on-chip computation and cross-chip communication, which means we may be able to overlap the two to keep hardware utilization high. These hybrid primitives are called collective matmuls.

This part of the lab has three main steps:

  1. Implementing all-gather-matmul
  2. Implementing matmul-reduce-scatter
  3. Composing them together to obtain a complete tensor-parallel neural network

Implementation

For the sake of our implementation, we’ll make two significant simplifying assumptions:

  1. Our kernel will only run on core 0, ignoring core 1. This means giving up half the FLOP/s on the TPU, a cost we accept for this exercise.

  2. All our model’s weights will live entirely in VMEM for the duration of our kernel. This significantly limits the maximum size of model we can support, but it gives us enough space to fit three layers of weights in memory, which is good enough for our purposes.

In a real-world setting, you would probably want to generalize beyond both of these assumptions.
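
As a quick sanity check on assumption 2 (assuming the weights are stored as bfloat16, consistent with the bfloat16 matmuls used below, and using the per-device shard sizes K_1 and K_2 defined above):

bytes_per_elem = 2                        # bfloat16
w1_bytes = 1024 * 4096 * bytes_per_elem   # one W1^i shard: [K1, K2]
w2_bytes = 4096 * 1024 * bytes_per_elem   # one W2^i shard: [K2, K1]
print(3 * (w1_bytes + w2_bytes) / 2**20)  # 48.0 MiB for three layers, vs. 64 MiB of VMEM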

Before writing any Pallas code, consider a few theoretical questions:

Question 1 for final write-up: In your write-up, answer the following questions:

  1. Our TPU v5p has an advertised bfloat16 matmul throughput of ~459 TFLOP/s, or ~230 TFLOP/s per core. Assuming we only use one core, what is the minimum time it would take to run a single \text{All-Gather}(Y) \, W^i_1 matmul, if compute were the only constraint? How about a single Z \, W^i_2 matmul?

  2. How much incoming and outgoing data do we need to communicate over each chip’s ICI in our all-gather-matmul operation? How about our matmul-reduce-scatter operation? What is the minimum time this would take for each operation, if ICI bandwidth were the only constraint?

  3. Of (1) and (2), which is larger?

Warm-Up: Matmul

Before implementing collective matmuls, let’s see how to implement a regular matmul in Pallas. Compared to CUDA, this is nearly trivial. Pallas exposes a function pl.dot (as in “dot product”) with the following signature:

pl.dot(a, b, trans_a=False, trans_b=False)

This pl.dot function takes two bfloat16 matrix values, and returns their matrix product as a float32 array. (Some 8-bit formats are also supported, but we won’t be using them in this lab.) You can cast its return value back to bfloat16 using

jnp.astype(result, jnp.bfloat16)

The trans_a and trans_b arguments allow you to optionally interpret either side as transposed prior to performing the matmul.

This pl.dot function is how Pallas exposes the TPU’s systolic array functional units to the programmer. Similar to the tensor cores on an NVIDIA GPU, the TPU’s systolic arrays provide hardware acceleration for performing matrix multiplies on fixed-size tiles of data. When you call pl.dot, Pallas will automatically decompose whatever matrix multiplication you asked for into a sequence of smaller matrix multiplication instructions sized appropriately for the systolic array. As a general heuristic, you can expect pl.dot to perform best when the size of each matrix dimension of both a and b is a multiple of 128.
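
Putting the pieces together, here is a hedged usage sketch of pl.dot inside a kernel body (generic ref names, not the starter code’s):

import jax.numpy as jnp
from jax.experimental import pallas as pl

def small_matmul_kernel(a_ref, b_ref, out_ref):
    acc = pl.dot(a_ref[...], b_ref[...])           # bfloat16 x bfloat16 -> float32
    out_ref[...] = jnp.astype(acc, jnp.bfloat16)   # cast back to bfloat16 for storage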

To see how pl.dot works and measure its performance, fill in matmul_pallas_kernel:

Deliverable: In collective_matmul.py, implement matmul_pallas_kernel to compute the matrix multiplication of x_ref with w_ref. Using pl.dot, this should be a single line.

Collective Matmuls

Now we’re ready for the main event—implementing collective matmuls!

Deliverable: In collective_matmul.py, implement all_gather_matmul_pallas_kernel and matmul_reduce_scatter_pallas_kernel as described above, with a relative RMSE of less than 1e-2. Try to achieve the best performance you can. Both implementations should reach at least 195 TFLOP/s.

Once you’ve implemented these kernels, move on to the complete tensor-parallel neural network:

Deliverable: In collective_matmul.py, implement neural_network_pallas_kernel with a relative RMSE of less than 1e-2. You may find it helpful to reuse code from your collective matmul implementations. The implementation should achieve at least 150 TFLOP/s.

Pallas Tips

Unlike in Part 1, the kernels in this part fundamentally need to operate on multi-dimensional arrays, which we can no longer think of as just one-dimensional arrays of tiles. When working with multi-dimensional arrays in Pallas, there are a few tips which it may be helpful to keep in mind:

  1. You can index Pallas arrays multi-dimensionally by specifying multiple indices, like NumPy arrays.

    • For example, if you want to load a 2D region of a 2D array ref, you can use the syntax my_arr_ref[pl.ds(offset0, size0), pl.ds(offset1, size1)].

    • To obtain a ref to that region, you can use .at, as in my_arr_ref.at[pl.ds(offset0, size0), pl.ds(offset1, size1)].

  2. Try to keep all indexing aligned to (8, 128) tiles if you can. This means offsets and sizes aligned to a multiple of 8 on the second-to-last dimension, and aligned to a multiple of 128 on the final dimension. Pallas’s compiler is often unable to handle operations that aren’t aligned to tile boundaries, and will throw an error if you try.

  3. Try to keep all source/destination buffers in your RDMAs contiguous in memory and tile-aligned.

    • You can think of tiles as being laid out in row-major order in memory, so, e.g., for an array ref of size [1024, 1024], the region arr_ref.at[pl.ds(32, 16), pl.ds(0, 1024)] would be contiguous, but the region arr_ref.at[pl.ds(0, 1024), pl.ds(32, 16)] would not.
  4. If Pallas complains that it can’t “prove” a dynamic index is a multiple of some value it expects it to be aligned to, you can use the function pl.multiple_of to obtain a value which the compiler knows is a multiple of the desired value.
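
Putting tips 1-4 together, here is a hedged sketch (generic names, nothing from the starter code) that copies one contiguous, tile-aligned row block out of a (1024, 1024) array ref:

from jax.experimental import pallas as pl

def copy_row_block(src_ref, dst_ref, block_idx):
    # 16 rows per block, so row offsets are multiples of 16 (and hence of 8);
    # pl.multiple_of lets the compiler assume this about the dynamic index.
    row_offset = pl.multiple_of(block_idx * 16, 16)
    block_ref = src_ref.at[pl.ds(row_offset, 16), pl.ds(0, 1024)]  # contiguous, tile-aligned
    dst_ref[pl.ds(0, 16), pl.ds(0, 1024)] = block_ref[...]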