Lab 11: Meet the TPU
Prologue: Logistics
Pair Programming
As in Labs 9 and 10, you’ll be working in teams of two for this lab.
Due Dates
For this lab, you’ll be turning in the following deliverables:
- Checkpoint: Due Monday, November 24, 11:59 pm. On Gradescope, please provide a short update on your progress on the lab.
- Final Submission: Due Friday, December 5, 11:59 pm. Submit your completed code for all parts of the lab, as well as a PDF writeup containing your answer to Question 1 below.
Starter Code
You can get the starter code for this lab by cloning the lab repository:
git clone git@github.com:accelerated-computing-class/lab11.git
You can run the Python code using uvx telerun like usual.
Goals for This Lab
Unlike in previous labs, where we programmed a single NVIDIA GPU, in this lab we will be programming a 2x2 mesh of Google TPU v5p devices. This means we will now have to think carefully about networking and inter-device communication. So, what do we have to work with?
Interconnect Topology
In this lab, we will be programming a scale-up domain of multiple devices connected by an inter-chip interconnect (“ICI”) fabric.
In our particular setup, we have 4 TPU v5p devices arranged in a 2x2 grid, with
each device connected to one horizontal neighbor and one vertical neighbor. We
can think of this topology as a one-dimensional ring of 4 devices, and label
our devices clockwise from the top-left with indices 0, 1, 2, 3 to mark their
positions in the ring:
In real-world use cases, the TPU v5p platform supports connecting much larger numbers of chips together, reaching scale-up domain sizes as large as thousands of chips. In addition to supporting very large scale-up domains, TPU v5p also supports connecting chips together in a 3D mesh topology, allowing each device to connect to up to 6 neighboring devices, as opposed to the 2 neighbors per device we use in this lab.
Each TPU v5p ICI link supports up to ~92 GB/s of bandwidth in each direction, with a per-hop latency of roughly ~1 μs. Bandwidth is not fungible between links or between directions.
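To put those numbers in perspective (simple arithmetic from the figures above): a 4 KiB message occupies a link for only about 4 KiB / 92 GB/s ≈ 45 ns, so small transfers are dominated by the ~1 μs per-hop latency, whereas a 64 MiB transfer occupies a link for roughly 0.7 ms and is firmly bandwidth-bound.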
On-Chip Resources
Inside each TPU v5p, we have two identical cores. Each core is a single-threaded programmable processor, and executes sequences of instructions in a proprietary VLIW-flavored ISA. Each TPU core is roughly analogous to a single core on a CPU, or to a single warp scheduler on a GPU—it has an SRAM scratchpad, a register file, dynamic control logic, and functional units for integer math, floating point math, and other operations.
The TPU makes up for its low degree of multi-core parallelism by implementing a
very high degree of SIMD parallelism. For 32-bit floating-point data, the
TPU natively operates in terms of 1024-element-wide vectors. Each 1024-wide
vector is usually thought of as a two-dimensional array of shape (8, 128),
which Google calls a “tile.”
Compared to CPUs or GPUs, one especially notable feature of TPUs is that each TPU core possesses a relatively large, high-bandwidth SRAM scratchpad called “VMEM” (short for “Vector Memory”). On our TPU v5p, VMEM is 64 MiB per core. For comparison, an H100 GPU has only 0.3% as much shared memory per SM (227 KiB vs 64 MiB) and ~23% as much total SRAM scratchpad capacity per chip (30 MiB vs 128 MiB).
Although we will mostly be using VMEM for Part 1 of this lab, it’s worth noting that each of our TPU v5p devices also has 95 GB of attached HBM2e memory shared between the two cores. By contrast to a GPU’s complex memory hierarchy, HBM and VMEM are the only two significant general-purpose memory resources on a TPU; the TPU has no L2 cache, L1 cache, or high-capacity register file like one would find on an NVIDIA GPU.
Further reading: For more information on the hardware architecture of TPUs, the free online book “How to Scale Your Model” is an excellent resource.
Part 1: Implementing Collectives in Pallas
In Part 1 of this lab, we will implement high-performance TPU kernels to carry out two fundamental collective operations: reduce-scatter and all-gather. These two operations can be composed together to implement all-reduce.
Background: Pallas and TPUs
In 2025, there are two main ways to program TPUs:
- JAX, a high-level bulk-array-processing language embedded in Python, operating at roughly the same level of abstraction as NumPy or PyTorch.
- Pallas, a lower-level array-processing language also embedded in Python, operating at roughly the same level of abstraction as Triton or CUDA with CUTLASS templates. (In ancient Greek mythology, Pallas was the daughter of Triton.)
In this lab, we will be exclusively programming in Pallas. You will not need to interact with JAX at any point (except insofar as Pallas shares some library infrastructure with JAX).
Coming from CUDA, Pallas is familiar to us in many ways:
- Pallas is an imperative, stateful language—it supports pointers and random-access reads and writes through those pointers.
- Pallas does not enforce strong safety guarantees for user code—it’s possible to introduce bugs like out-of-bounds indexing or race conditions.
- Pallas maps closely onto the hardware—Pallas is currently the lowest-level interface publicly available for programming TPUs, and can express most kernels that the TPU is physically capable of executing.
However, Pallas differs from CUDA in two key ways: it uses a tracing-based compilation model and an array-based programming abstraction.
Tracing-Based Design
The way you define a Pallas program is by “tracing” a Python program. This means that execution happens in two phases:
- First, you run your Python program on special “traced”/“symbolic” objects representing values which will exist at runtime in your kernel, but which have no concrete value yet. When your Python code performs operations on these traced values, those operations emit instructions which are recorded by Pallas’s tracing infrastructure, but no code actually executes on the TPU.
- Second, once tracing is complete, Pallas compiles a kernel from the emitted instructions. You can then run this kernel on the TPU as many times as you like, with different inputs each time.
Pallas’s tracing-based architecture has a few notable implications:
- Python control flow is not kernel control flow: When you use Python control flow constructs like if or while inside a traced Pallas kernel, the control flow happens at trace time, not at runtime. This means any control flow decisions you make in Python cannot depend on dynamic values. If you want to write a kernel which branches or loops on the TPU based on some dynamic condition, you need to use special traced Pallas constructs which emit branch or loop instructions, like pl.when, lax.cond, lax.fori_loop, or lax.while_loop (all of which work in Pallas, despite some of them looking like they belong exclusively to JAX).
- Python print statements are not kernel print statements: Similar to control flow, any print statements you put in your code will run at trace time, not at kernel runtime. If you want to print a dynamic value in your kernel, you must use the traced print function jax.debug.print, and set ENABLE_DEBUG=True in the starter code.
- The only thing that matters for performance is what sequence of instructions your code emits, not how it emits them. At trace time, you can use any features of Python you want—strings, dictionaries, custom classes, higher-order functions, etc.—to structure how your code emits instructions without affecting runtime performance.
You can read more about tracing in JAX (which works similarly to Pallas) here.
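To make the two phases concrete, here is a minimal, self-contained sketch of a kernel body (not from the starter code; it assumes refs of shape (8, 128) and a kernel launched with a 1-D grid so that pl.program_id(0) is meaningful):

from jax.experimental import pallas as pl
from jax import lax
import jax

def example_kernel(x_ref, o_ref):
    # Trace-time loop: this Python `for` runs during tracing, so it simply
    # emits four add operations into the kernel; no loop exists at runtime.
    acc = x_ref[...]
    for _ in range(4):
        acc = acc + acc

    # Runtime loop: lax.fori_loop emits an actual loop instruction.
    acc = lax.fori_loop(0, 4, lambda i, a: a + a, acc)

    # Runtime branch: pl.when emits a conditional on a dynamic value.
    @pl.when(pl.program_id(0) == 0)
    def _():
        # Runtime print: executes when the kernel runs, not at trace time.
        jax.debug.print("hello from grid index 0")

    o_ref[...] = acc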
Array-Based Design
Unlike CUDA, Pallas has convenient built-in support for performing bulk operations on multi-dimensional arrays. In fact, multi-dimensional arrays are the primary abstraction Pallas provides for working with numerical data. Arrays in Pallas come in two flavors:
- Array refs, which are like pointers with associated multi-dimensional size information. Like pointers, Pallas array refs may be loaded from, stored to, offset into, and passed to APIs which operate on regions of memory, such as Pallas’s remote-DMA API.
- Array values, representing numerical data which the compiler will (usually) temporarily materialize in vector registers, and which may be passed directly to and returned from vectorized math operations.
When you load from a Pallas array ref, what you get is an array value. Similarly, when you store to a Pallas array ref, you must provide an array value.
The Pallas compiler automatically exploits SIMD and VLIW parallelism
on your behalf as part of its array-based programming model. For example, if x
and y are both traced Pallas array values of shape (2048, 1024) with
elements of type fp32, then you can compute their elementwise sum simply by
writing the expression x + y, and the compiler will automatically decompose
that bulk array operation into a sequence of individual vector instructions,
packed efficiently into VLIW bundles.
With a few exceptions, the Pallas compiler generally requires that shape
information for all arrays must be known at trace time. This imposes some
restrictions on what you are able to do with arrays, but it also means that your
own Python code can query the concrete .shape of most Pallas array refs or
array values at trace time.
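As a tiny illustration of the ref/value distinction (a sketch with assumed ref names, not from the starter code):

def add_kernel(x_ref, y_ref, o_ref):
    x = x_ref[...]      # load: array ref -> array value
    y = y_ref[...]
    o_ref[...] = x + y  # bulk elementwise add; the compiler vectorizes this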
Warm-Up: exchange_with_neighbor
With all that preamble out of the way, let’s write our first Pallas program!
Since our goal for this lab is to make use of ICI, we’ll start by writing a simple program that sends data between multiple devices. Our objective in this section will be to implement an “exchange with neighbor” operation, which should copy arrays between devices according to the following pattern:
- Device 0 sends its input array to device 1
- Device 1 sends its input array to device 0
- Device 2 sends its input array to device 3
- Device 3 sends its input array to device 2
To do this, you’ll need to fill in the body of the function
exchange_with_neighbor_pallas_kernel. A few notes on the execution model of
this function:
- The starter code already contains all the necessary boilerplate to ensure that the Python code in exchange_with_neighbor_pallas_kernel will be executed in a traced Pallas context, and used to define a Pallas kernel.
- A copy of the Pallas kernel defined by exchange_with_neighbor_pallas_kernel will execute on each of the 4 TPUs in our mesh simultaneously, in SPMD fashion. This is similar to how, when you launch a CUDA kernel, a copy of it will execute on every thread and block in your grid.
- Although each of our TPUs has 2 cores, we will only use core 0 on each device, and completely ignore core 1. (You don’t need to do anything special to implement this; your kernel will run on core 0 by default.) Using just a single core per device is sufficient to completely saturate the bandwidth of our ICI links, and since all our kernels for Part 1 will be bottlenecked by communication rather than computation, that means a single core is good enough for our purposes.
In our kernel, we’re going to need to transfer data between devices. To do this, we can use Pallas’s remote-DMA (“RDMA”) functionality. The starter code provides the following convenience functions for working with RDMAs:
- pallas_rdma_start: Start an asynchronous copy from the current device to a remote device.
- pallas_rdma_wait_send: Wait until an in-flight RDMA originating on the current device is done reading data.
- pallas_rdma_wait_recv: Wait until an in-flight RDMA targeting the current device is done writing data.
(We recommend you use these wrapper functions rather than Pallas’s built-in RDMA API, as we believe the wrappers express the underlying semantics slightly more clearly than their built-in equivalents.)
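To show how these calls typically fit together, here is a heavily hedged sketch of the ordering. The argument names below are hypothetical, not the wrappers’ actual signatures; consult the docstrings in the starter code for the real interface:

# Hypothetical argument names, for illustration of ordering only:
pallas_rdma_start(src_ref, dst_ref, send_sem, recv_sem, device_id)  # kick off the async copy
# ... optionally do other useful work while the RDMA is in flight ...
pallas_rdma_wait_send(send_sem)  # block until our outgoing data has been read
pallas_rdma_wait_recv(recv_sem)  # block until data another device sent us has landed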
A peculiar feature of the pallas_rdma_start API is that it demands that you
pass it a pointer into another device’s address space, to identify the
destination buffer for your RDMA. But how can we obtain a pointer into a
different device’s address space? How are we supposed to know where anything is?
The answer is that we can exploit a trick: because the Pallas compiler guarantees that your kernel’s address space is laid out identically on each device, a pointer to a buffer on your current device can also identify the corresponding buffer on a different device!
As one final point, you may notice from reading the starter code that the RDMA API requires the use of DMA semaphores. DMA semaphores are a special type of Pallas array ref used to track progress on DMAs. Conceptually, you can think of each DMA semaphore as an integer counter:
- As the DMA makes progress, the DMA engine increments the counter to reflect the number of bytes transferred so far.
- To wait on the DMA, the TPU core waits until the counter’s value reaches the expected total number of bytes to transfer, and then decrements the counter to zero.
You can allocate DMA semaphores for your kernel by filling in the
exchange_with_neighbor_pallas_scratch_specs function:
- If you allocate a single DMA semaphore with pltpu.SemaphoreType.DMA, you can pass the resulting array ref directly to RDMA API functions.
- If you allocate an array of multiple DMA semaphores with pltpu.SemaphoreType.DMA(shape=(N,)), you can obtain array refs pointing to individual semaphores in the array using the syntax my_semaphore_array.at[my_index].
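For instance, a scratch-specs function might look like the following sketch (the exact signature and return convention are defined by the starter code; this only illustrates the two allocation forms described above):

from jax.experimental.pallas import tpu as pltpu

def exchange_with_neighbor_pallas_scratch_specs():
    return [
        pltpu.SemaphoreType.DMA,              # one DMA semaphore
        pltpu.SemaphoreType.DMA(shape=(4,)),  # an array of 4 DMA semaphores
    ]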
You now have everything you need to implement an “exchange with neighbors” kernel!
Deliverable: Implement exchange_with_neighbor_pallas_kernel in collectives.py so that it sends data between devices in the expected pattern (you should see a rel_rmse of 0.00e+0 when you run it).
The code for your “exchange with neighbors” kernel will likely end up looking very simple, and that’s expected! This section is just to make sure we have all the fundamentals down before moving on to implementing real collectives.
Reduce-Scatter and All-Gather
We’re now ready to implement two more interesting collective kernels: reduce-scatter and all-gather. Both of these collectives conceptually take a one-dimensional vector of numbers on each device as input, and produce a one-dimensional vector of numbers of a different size as output. The semantics of each are as follows:
- Reduce-scatter combines all devices’ input arrays by summing them elementwise, and shards the output of that elementwise sum across devices, splitting the result into evenly-sized chunks.
- All-gather concatenates all devices’ input arrays together into a single larger array, and replicates that concatenated array across devices.
(There are also variants of reduce-scatter which use reduction operators other than “sum.” As long as the reduction operator is commutative and associative, the implementation will look more or less the same.)
In our case, the numbers in our arrays will all be of type float32.
To make our collective implementations play nicely with the TPU hardware, we
won’t implement them in terms of one-dimensional Pallas arrays. Instead, we’ll
view each input/output array as a sequence of (8, 128) tiles, so that each
input/output array has a three-dimensional shape of the form (N, 8, 128).
For the most part, you can treat tiles as indivisible units, and only worry
about slicing or concatenating these three-dimensional arrays along their first
dimension.
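As a reference for the semantics only (plain NumPy, not a Pallas kernel, and assuming N is divisible by 4):

import numpy as np

def reduce_scatter_reference(inputs):
    # inputs: list of 4 per-device arrays, each of shape (N, 8, 128)
    total = np.sum(inputs, axis=0)        # elementwise sum across devices
    return np.split(total, 4, axis=0)     # device i keeps chunk i: shape (N/4, 8, 128)

def all_gather_reference(inputs):
    # inputs: list of 4 per-device arrays, each of shape (N, 8, 128)
    gathered = np.concatenate(inputs, axis=0)   # shape (4*N, 8, 128)
    return [gathered.copy() for _ in range(4)]  # replicated on every device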
We want to develop our own implementations of reduce-scatter and all-gather in Pallas which achieve high throughput across a range of sizes:
Deliverable: Implement the functions reduce_scatter_pallas_kernel and all_gather_pallas_kernel to perform reduce-scatter and all-gather. For full credit, try to achieve at least the following effective bandwidths at each input array size:

Reduce-scatter target bandwidth:
- (16, 8, 128): 14 GB/s
- (128, 8, 128): 78 GB/s
- (1024, 8, 128): 160 GB/s
- (2048, 8, 128): 170 GB/s
- (4096, 8, 128): 178 GB/s

All-gather target bandwidth:
- (4, 8, 128): 14 GB/s
- (32, 8, 128): 78 GB/s
- (256, 8, 128): 160 GB/s
- (512, 8, 128): 170 GB/s
- (1024, 8, 128): 178 GB/s
(Note that “bandwidth” here refers to achieved outgoing bandwidth per device for reduce-scatter, and achieved incoming bandwidth per device for all-gather.)
Useful Features of Pallas
In your implementation, you’ll probably need to use a few more features of Pallas that we haven’t covered yet:
- Array ref loads/stores:
  - You can load an array value of size M from offset i inside a larger array ref in VMEM using the syntax my_array_ref[pl.ds(i, M)].
  - You can store an array value of size M to offset i inside a larger array ref in VMEM using the syntax my_array_ref[pl.ds(i, M)] = my_array_value.
  - If your array has more than one dimension, like our three-dimensional (N, 8, 128) arrays, the above expressions will slice along the first dimension, and any trailing dimensions will be broadcast like in NumPy. This is convenient if you want to treat tiles as indivisible units.
  - M must be a concrete integer known at trace time.
- Array ref slices:
  - You can refer to sub-slices of an array ref using its .at property: to obtain a slice of size M starting at index i, you can use the syntax my_array_ref.at[pl.ds(i, M)].
    - Slicing an array ref with .at will return another array ref, not an array value. This is analogous to computing a pointer offset in C/CUDA, as opposed to loading/storing through a pointer.
    - The array ref returned by .at will point at the same underlying memory as the original array ref.
    - In this case, M doesn’t always have to be a concrete integer known at trace time—for some downstream operations, including (R)DMAs, the Pallas compiler will accept passing a traced value for the size parameter M.
- Scratch buffers:
  - You can allocate extra VMEM buffers for use in your kernel by using pltpu.VMEM in reduce_scatter_pallas_scratch_specs / all_gather_pallas_scratch_specs.
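Here is a short sketch pulling these features together (hypothetical helper and ref names; it assumes refs of shape (N, 8, 128) in VMEM and the usual import of jax.experimental.pallas as pl):

def copy_first_half(src_ref, dst_ref):
    n = src_ref.shape[0]                               # concrete at trace time
    chunk = src_ref[pl.ds(0, n // 2)]                  # load an array value of shape (n//2, 8, 128)
    dst_ref[pl.ds(0, n // 2)] = chunk                  # store it at the front of dst_ref
    back_half_ref = dst_ref.at[pl.ds(n // 2, n // 2)]  # a ref to the back half (no copy)
    # back_half_ref could now be used, e.g., as the destination buffer of an RDMA.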
Implementation Tips
- Be careful if you reuse send/recv semaphores! A given DMA semaphore can only participate in one RDMA operation at a time, and must be waited on before it can be reused by a different RDMA. This means in particular that before you start an RDMA to a remote device, you must make sure that the destination device has already waited on any prior RDMAs in your kernel which use the same dst_recv_sem you specified.
  - If you want to reuse DMA semaphores, you can enforce the necessary synchronization by using an additional pltpu.SemaphoreType.REGULAR semaphore, which is effectively an integer counter exposing the following API:
    - pltpu.semaphore_signal(dst_regular_sem, inc_count, device_id=dst_device_id): increment the value of dst_regular_sem on a remote device.
    - pltpu.semaphore_wait(regular_sem, dec_count): wait until regular_sem on the current device reaches at least dec_count, then decrement it by dec_count and continue.
  - Alternatively, the approach we recommend is to not reuse any DMA semaphores. For the array sizes we’re interested in, the TPU should have more than enough available semaphore memory for you to be able to get away with using a unique DMA semaphore for every RDMA.
- Be careful if you reuse buffers in VMEM! Similar concerns as apply to semaphore reuse also apply to using the same VMEM buffer as the destination of multiple RDMAs. You can use REGULAR semaphores to synchronize repeated access to VMEM buffers across devices, but you may find it easier to simply avoid ever RDMAing to the same interval of VMEM twice. The TPU has enough VMEM capacity to accommodate this.
- All sends must be waited on. Every time your kernel calls pallas_rdma_start, you must make sure it also eventually makes exactly one corresponding call to pallas_rdma_wait_send before the kernel exits.
- Only send to direct neighbors. The TPU hardware supports routing RDMAs to devices which are not your device’s direct neighbors, but doing so will consume bandwidth on multiple ICI links between you and the destination device, causing congestion. To take maximum advantage of the ICI bandwidth available to you, we recommend only ever sending to direct neighbors.
- Back-to-back RDMAs are a great way to hide latency. Whenever you perform a blocking operation (e.g. pallas_rdma_wait_send or pallas_rdma_wait_recv), try to make sure that you will have at least one in-flight RDMA which will continue running for the whole time the core is blocked, so that you are always running the ICI links at maximum utilization.
  - A good pattern to achieve this is to pipeline your RDMAs by launching multiple RDMAs back-to-back, and then waiting on early RDMAs in the sequence while later RDMAs in the sequence are still in flight.
Part 2: Overlapping Collectives with Computation
Now that we’ve implemented reduce-scatter and all-gather, let’s see how we can put them to use!
Workload: A Tensor-Parallel Neural Network
In this section, we’ll implement a simple neural network model consisting of a sequence of dense matrix multiplications. Each layer of our network takes as input a batched activation matrix $X$ of shape $(\text{batch}, d_\text{model})$, along with two weight matrices $W_1$ of shape $(d_\text{model}, d_\text{ff})$ and $W_2$ of shape $(d_\text{ff}, d_\text{model})$, and performs the following operation:

$$Y = \sigma(X W_1)\, W_2$$

where $\sigma$ is an elementwise nonlinearity (for this lab, we’ll use gelu). Our complete model consists of multiple such layers stacked together, each with different weights.
We could shard this model over the 4 TPUs in our ICI domain in several ways, such as data parallelism or pipeline parallelism. In this lab, we’ll implement tensor parallelism.
Tensor parallelism exploits the fact that each layer can be rewritten as a sum of contributions from independent weight subsets. If we decompose $W_1$ into column blocks $W_1^{(0)}, \dots, W_1^{(3)}$ and $W_2$ into row blocks $W_2^{(0)}, \dots, W_2^{(3)}$:

$$W_1 = \begin{bmatrix} W_1^{(0)} & W_1^{(1)} & W_1^{(2)} & W_1^{(3)} \end{bmatrix}, \qquad W_2 = \begin{bmatrix} W_2^{(0)} \\ W_2^{(1)} \\ W_2^{(2)} \\ W_2^{(3)} \end{bmatrix}$$

then each layer’s computation becomes:

$$Y = \sum_{i=0}^{3} \sigma\!\left(X W_1^{(i)}\right) W_2^{(i)}$$

If $X$ is replicated across all devices, each device can independently compute one term, and then we sum the results at the end of the layer. How do we sum them? An all-reduce!
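If you want to convince yourself the decomposition is exact, here is a quick NumPy check (illustration only, with toy sizes; gelu here is the usual tanh approximation):

import numpy as np

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 16))    # (batch, d_model)
W1 = rng.standard_normal((16, 32))  # (d_model, d_ff)
W2 = rng.standard_normal((32, 16))  # (d_ff, d_model)

full = gelu(X @ W1) @ W2
W1_cols = np.split(W1, 4, axis=1)   # column blocks, one per device
W2_rows = np.split(W2, 4, axis=0)   # row blocks, one per device
sharded = sum(gelu(X @ W1_cols[i]) @ W2_rows[i] for i in range(4))

assert np.allclose(full, sharded)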
We’ll implement this tensor-parallel workload piece by piece, using the hyperparameter values defined in the starter code.
Since an all-reduce is a reduce-scatter followed by an all-gather, each layer’s computation from a single device’s perspective has the form:

$$Y = \text{AllGather}\!\left(\text{ReduceScatter}\!\left(\sigma\!\left(X W_1^{(i)}\right) W_2^{(i)}\right)\right)$$

Since our model consists of multiple layers running in sequence, for pairs of adjacent layers in the middle of the network we can reassociate the all-gather at the end of one layer into the beginning of the next, so that a typical sequence of operations in the middle of the network looks like

$$Y_\text{shard} = \text{ReduceScatter}\!\left(\sigma\!\left(\text{AllGather}\!\left(X_\text{shard}\right) W_1^{(i)}\right) W_2^{(i)}\right)$$

where $X_\text{shard}$ is the sharded output of the previous layer’s reduce-scatter.
We can then decompose this computation into two main sub-computations:
- All-Gather-Matmul: $H = \sigma\!\left(\text{AllGather}\!\left(X_\text{shard}\right) W_1^{(i)}\right)$
- Matmul-Reduce-Scatter: $Y_\text{shard} = \text{ReduceScatter}\!\left(H\, W_2^{(i)}\right)$
- (where $H$ is the intermediate tensor returned by the all-gather-matmul)
Both operations involve both on-chip computation and cross-chip communication, which means we may be able to overlap computation with communication to keep hardware utilization high. These hybrid primitives are called collective matmuls.
This part of the lab has three main steps:
- Implementing all-gather-matmul
- Implementing matmul-reduce-scatter
- Composing them together to obtain a complete tensor-parallel neural network
Implementation
For the sake of our implementation, we’ll make two significant simplifying assumptions:
- Our kernel will only run on core 0, ignoring core 1. This means giving up half the FLOP/s on the TPU, a cost we accept for this exercise.
- All our model’s weights will live entirely in VMEM for the duration of our kernel. This significantly limits the maximum size of model we can support, but it gives us enough space to fit three layers of weights in memory, which is good enough for our purposes.
In a real-world setting, you would probably want to generalize beyond both of these assumptions.
Before writing any Pallas code, consider a few theoretical questions:
Question 1 for final write-up: In your write-up, answer the following questions:
1. Our TPU v5p has an advertised bfloat16 matmul throughput of ~459 TFLOP/s, or ~230 TFLOP/s per core. Assuming we only use one core, what is the minimum time it would take to run a single $X W_1^{(i)}$ matmul, if compute were the only constraint? How about a single $H W_2^{(i)}$ matmul?
2. How much incoming and outgoing data do we need to communicate over each chip’s ICI in our all-gather-matmul operation? How about our matmul-reduce-scatter operation? What is the minimum time this would take for each operation, if ICI bandwidth were the only constraint?
3. Of (1) and (2), which is larger?
Warm-Up: Matmul
Before implementing collective matmuls, let’s see how to implement a regular
matmul in Pallas. Compared to CUDA, this is nearly trivial. Pallas exposes a
function pl.dot (as in “dot product”) with the following signature:
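(The sketch below is abbreviated: trans_a and trans_b are described next, and the real function also accepts a few additional optional arguments we won’t need in this lab.)

pl.dot(a, b, trans_a=False, trans_b=False)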
This pl.dot function takes two bfloat16 matrix values, and returns their
matrix product as a float32 array. (Some 8-bit formats are also supported, but
we won’t be using them in this lab.) You can cast its return value back to
bfloat16 using .astype(jnp.bfloat16).
The trans_a and trans_b arguments allow you to optionally interpret either
side as transposed prior to performing the matmul.
This pl.dot function is how Pallas exposes the TPU’s systolic
array functional units to the
programmer. Similar to the tensor cores on an NVIDIA GPU, the TPU’s systolic
arrays provide hardware acceleration for performing matrix multiplies on
fixed-size tiles of data. When you call pl.dot, Pallas will automatically
decompose whatever matrix multiplication you asked for into a sequence of
smaller matrix multiplication instructions sized appropriately for the systolic
array. As a general heuristic, you can expect pl.dot to perform best when the
size of each matrix dimension of both a and b is a multiple of 128.
To see how pl.dot works and measure its performance, fill in
matmul_pallas_kernel:
Deliverable: In collective_matmul.py, implement matmul_pallas_kernel to compute the matrix multiplication of x_ref with w_ref. Using pl.dot, this should be a single line.
Collective Matmuls
Now we’re ready for the main event—implementing collective matmuls!
Deliverable: In collective_matmul.py, implement all_gather_matmul_pallas_kernel and matmul_reduce_scatter_pallas_kernel as described above, with a relative RMSE of less than 1e-2. Try to achieve the best performance you can. Both implementations should reach at least 195 TFLOP/s.
Once you’ve implemented these kernels, move on to the complete tensor-parallel neural network:
Deliverable: In collective_matmul.py, implement neural_network_pallas_kernel with a relative RMSE of less than 1e-2. You may find it helpful to reuse code from your collective matmul implementations. The implementation should achieve at least 150 TFLOP/s.
Pallas Tips
Unlike in Part 1, the kernels in this part fundamentally need to operate on multi-dimensional arrays, which we can no longer think of as just one-dimensional arrays of tiles. When working with multi-dimensional arrays in Pallas, there are a few tips which it may be helpful to keep in mind:
- You can index Pallas arrays multi-dimensionally by specifying multiple indices, like NumPy arrays.
  - For example, if you want to load a 2D region of a 2D array ref, you can use the syntax my_arr_ref[pl.ds(offset0, size0), pl.ds(offset1, size1)].
  - To obtain a ref to that region, you can use .at, as in my_array_ref.at[pl.ds(offset0, size0), pl.ds(offset1, size1)].
- Try to keep all indexing aligned to (8, 128) tiles if you can. This means offsets and sizes aligned to a multiple of 8 on the second-to-last dimension, and aligned to a multiple of 128 on the final dimension. Pallas’s compiler is often unable to handle operations that aren’t aligned to tile boundaries, and will throw an error if you try.
- Try to keep all source/destination buffers in your RDMAs contiguous in memory and tile-aligned.
  - You can think of tiles as being laid out in row-major order in memory, so, e.g., for an array ref of size [1024, 1024], the region arr_ref.at[pl.ds(32, 16), pl.ds(0, 1024)] would be contiguous, but the region arr_ref.at[pl.ds(0, 1024), pl.ds(32, 16)] would not.
- If Pallas complains that it can’t “prove” a dynamic index is a multiple of some value it expects it to be aligned to, you can use the function pl.multiple_of to obtain a value which the compiler knows is a multiple of the desired value.
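For example (illustrative variable names only; a sketch of the usage pattern):

offset = pl.multiple_of(dynamic_offset, 128)     # the compiler now treats `offset` as 128-aligned
chunk_ref = my_array_ref.at[pl.ds(offset, 128)]  # safe to slice at this offset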