How to Fake Multiple CPUs in JAX
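Here’s how to make JAX see multiple CPU devices on a single machine. This makes it easy to test multi-device code written for several TPUs or GPUs (for example, code that uses pmap or sharded arrays) without actually needing the accelerators.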
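```python
import os

# Tell XLA's host (CPU) platform to expose 8 devices. The flag is read when
# JAX initializes its backends, so set it before importing jax.
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '') +
    " --xla_force_host_platform_device_count=8"
)

import jax
jax.devices()
```

Output:

```
[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]
```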
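Once the eight fake devices exist, multi-device code should run on them unmodified. The following is a minimal sketch, assuming a reasonably recent JAX release that provides `jax.sharding`. It also sets `JAX_PLATFORMS=cpu`, which is not part of the original tip, so that a machine with a GPU still uses the fake CPU devices. It shows two common patterns: `jax.pmap`, which runs one copy of a function per device, and `jax.jit` over an array split across devices with `NamedSharding`.

```python
import os

# Must happen before JAX initializes its backends (here: before importing jax).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)
# Assumption (not in the original tip): force the CPU backend even if a GPU exists.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()  # 8 fake CPU devices

# Pattern 1: pmap. The leading axis of the input must equal the device count;
# each device receives one row and sums it.
xs = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
row_sums = jax.pmap(lambda x: x.sum())(xs)

# Pattern 2: sharding. Build a 1-D mesh over all devices and split a vector
# along its only axis, so each device holds a contiguous chunk.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))
ys = jax.device_put(jnp.arange(n * 2, dtype=jnp.float32), sharding)

# jit compiles one program that runs across all shards.
doubled = jax.jit(lambda y: y * 2)(ys)

print(row_sums)
for shard in ys.addressable_shards:
    print(shard.device, shard.data)
```

Because every fake device is backed by the same host CPU, this setup checks correctness (shapes, sharding, collectives) rather than multi-device performance.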