How to Fake Multiple CPUs in JAX

Here’s how to emulate multiple CPU devices when running JAX. This makes it easy to test multi-TPU/GPU code on an ordinary machine, without access to the accelerators.

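import os
# Set the flag before importing jax: XLA reads it when the backend initializes.
# Appending keeps any XLA_FLAGS that are already set.
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '') +
    " --xla_force_host_platform_device_count=8"
)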

import jax
jax.devices()
[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]
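
Once the emulated devices show up, they can be used like real ones. Below is a minimal sketch, assuming a CPU-only host and a JAX version that still provides jax.pmap; the function name normalize and the axis name "devices" are made up for illustration. It puts one row of an array on each device and uses a cross-device sum to normalize the rows.

```python
import os

# Must be set before jax is imported (see above).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp

# Number of devices JAX sees; 8 on a CPU-only host with the flag above.
n = jax.local_device_count()

# One row per device: pmap splits the leading axis across devices.
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)


def normalize(row):
    # Hypothetical example function: divide each row by the global sum.
    # psum adds the per-device partial sums across the "devices" axis.
    total = jax.lax.psum(jnp.sum(row), axis_name="devices")
    return row / total


out = jax.pmap(normalize, axis_name="devices")(x)
print(out.shape)
```

The result has the same shape as x, with one row computed on each device, and its entries sum to 1 because every device divided by the same global total. The same code should run unchanged on real accelerators, where n would be the number of attached TPU or GPU devices.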