Doing Post-Quantum Cryptography in JAX
In 2018, while studying ML and cryptography around the same time, I realized that many cryptographic algorithms can be expressed as computation graphs, the same ones supported by major ML frameworks. That led to a completely frivolous attempt to implement cryptographic algorithms in TensorFlow, just to see if it would work.
The world has changed a lot since 2018. JAX is growing in popularity as an ML framework, and in the cryptography space, the quantum threat that post-quantum cryptography defends against has gone from mostly theoretical to slightly more real. So, continuing the motivation of the original post, I decided to look into implementing a post-quantum cryptography algorithm in JAX.
As a heads up: please do not use any of this code for real cryptography.
Lattice-based cryptography is one of the leading areas of post-quantum algorithms. Algorithms like Diffie-Hellman and RSA rely on the hardness of computing discrete logarithms or factoring large integers, both of which become easy on a large enough quantum computer running Shor’s algorithm. Lattice-based problems, by contrast, are currently thought to resist quantum attacks because no known quantum algorithm solves them efficiently.
Learning with Errors (LWE) is one of the most popular lattice-based problems to build schemes on. The following code demonstrates a Diffie-Hellman-style public key encryption scheme using LWE in JAX.
import jax.numpy as jnp
from jax import random

# Global parameters
# Dimension of the secret key `s`
n = 4
# Modulus - the math will be modulo this number in order to put the computation in a finite field
q = 23
# Number of bits in each message (also the number of rows in the public matrix `A`)
m = 8
The first step is to generate a keypair. The private key (s) is a vector of
length 4 (the value of n chosen above).
# Generate Keypair
key_rng = random.PRNGKey(42)
secret_rng, pub_rng = random.split(key_rng)
# Secret Key
s = random.randint(secret_rng, shape=(n,), minval=0, maxval=q)
s
# => Array([ 1, 13, 3, 22], dtype=int32)
The public key is a matrix A and vector b.
# The `A` random matrix is the first component of the public key.
A = random.randint(pub_rng, shape=(m, n), minval=0, maxval=q)
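# `e` is the error vector, used to generate `b` then thrown out.
# Each entry is drawn from {-1, 0, 1} (`maxval` is exclusive).
# Caveat: this reuses `pub_rng`, so `e` is not independent of `A`; JAX
# expects a fresh key per draw, e.g. from another `random.split`.
e = random.randint(pub_rng, shape=(m,), minval=-1, maxval=2)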
print(A.shape) # => (8, 4)
print(e.shape) # => (8,)
# Generate second component of public key, vector `b`, from
# public matrix `A` with the secret key `s` plus the error vector,
# and modulo `q` to put it within the finite field.
b = (A @ s + e) % q
b
# => Array([ 2, 18, 6, 1, 20, 9, 13, 10], dtype=int32)
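The relationship between the two halves of the keypair can be checked directly. The following is a minimal sketch, not the code that produced the outputs above: it uses a different seed, gives each random draw its own PRNG key, and the names residual and centered are my own. It shows that anyone holding s can recover the small error vector from the public key.

```python
import jax.numpy as jnp
from jax import random

n, m, q = 4, 8, 23

# One independent PRNG key per random draw (s, A, and e).
s_rng, a_rng, e_rng = random.split(random.PRNGKey(0), 3)
s = random.randint(s_rng, shape=(n,), minval=0, maxval=q)
A = random.randint(a_rng, shape=(m, n), minval=0, maxval=q)
e = random.randint(e_rng, shape=(m,), minval=-1, maxval=2)  # entries in {-1, 0, 1}
b = (A @ s + e) % q

# Knowing s, the public key is "almost" consistent: b - A s equals e modulo q.
residual = (b - A @ s) % q
# Re-center residues from [0, q) to the symmetric range around zero,
# so that q - 1 reads as -1.
centered = jnp.where(residual > q // 2, residual - q, residual)
print(centered)  # small values only: each entry is -1, 0, or 1
```

Without s, the same b is meant to look like a uniformly random vector, which is the hardness assumption LWE rests on.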
We now have public key (A, b) and private key s.
We can use this to encrypt and decrypt messages in an asymmetric
encryption scheme.
Encryption uses public key (A, b) to encrypt a binary vector message. The
random binary mask r introduces randomness into the process (like e above), to
ensure the same message does not always map to the same ciphertext, for
security. The result is mapped into the finite field modulo q: the plaintext
bit 0 will be mapped close to zero, and the plaintext bit 1 will be mapped
close (but not exactly!) to the middle of the field (q // 2). For example,
because q is 23, 0 will be 0 and 1 will be 11 (because 23 // 2 = 11).
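To make the "close but not exactly" part concrete, here is a small sketch of the encoding on its own, with hand-picked noise standing in for the randomness that encryption adds. The helper names encode_bits and decode_values are hypothetical, and the decoding window used here (strictly between q // 4 and 3 * q // 4) is one reasonable choice rather than the only one.

```python
import jax.numpy as jnp

q = 23

def encode_bits(bits, q):
    """Map bit 0 to 0 and bit 1 to the middle of the field, q // 2."""
    return bits * (q // 2)

def decode_values(values, q):
    """Decode as 1 when the value lies in the middle band of [0, q)."""
    v = values % q
    return jnp.where((v > q // 4) & (v < 3 * q // 4), 1, 0)

bits = jnp.array([1, 0, 1, 0])
noise = jnp.array([2, -3, -4, 5])          # illustrative noise, chosen by hand
noisy = (encode_bits(bits, q) + noise) % q

print(noisy)                    # [13 20  7  5]
print(decode_values(noisy, q))  # [1 0 1 0]
```

Note that the negative noise on a 0 bit wraps around to 20, just below q, which is why decoding has to treat values near q as "close to zero".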
This produces ciphertext (c1, c2) where c1 is a vector with the same size
as the secret key (4), and c2 is a vector with the same length as the
message.
message = jnp.array([1, 1, 0, 0, 1, 1, 0, 0])
message
# => Array([1, 1, 0, 0, 1, 1, 0, 0], dtype=int32)
enc_rng = random.PRNGKey(24)
# Generate a random binary vector as the mask
r = random.bernoulli(enc_rng, shape=(m,))
r = r.astype(jnp.int32)
# Encryption: c1 = (r \cdot A) % q, c2 = (r \cdot b + msg * (q // 2)) % q
c1 = jnp.dot(r, A) % q
c2 = (jnp.dot(r, b) + message * (q // 2)) % q
print('Ciphertext c1:', c1)
print('Ciphertext c2:', c2)
# => Ciphertext c1: [ 2 15 3 12]
# => Ciphertext c2: [21 21 10 10 21 21 10 10]
# Note: a single mask `r` is shared by every bit here, so c2 only takes two
# distinct values. A real system would draw fresh randomness for each bit.
Decryption combines the secret key s with the ciphertext (c1, c2):
intermediate = (c2 - c1 @ s) % q
intermediate
# => Array([11, 11, 0, 0, 11, 11, 0, 0], dtype=int32)
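Why this recovers the message can be seen by substituting the definitions of c1, c2, and b into the decryption expression. The derivation below is a worked sketch with all arithmetic modulo q; the symbol \mu for a message bit is my notation, not the code's.

```latex
\begin{aligned}
c_2 - c_1 \cdot s
  &\equiv \left(r \cdot b + \mu \lfloor q/2 \rfloor\right) - (rA) \cdot s \\
  &\equiv r \cdot (As + e) + \mu \lfloor q/2 \rfloor - r \cdot (As) \\
  &\equiv r \cdot e + \mu \lfloor q/2 \rfloor \pmod{q}
\end{aligned}
```

The secret-dependent terms cancel, leaving the encoded bit plus the noise term r \cdot e. Since r is binary and each entry of e is -1, 0, or 1, that noise is at most m = 8 in absolute value. With q = 23, decoding is only guaranteed while the noise stays within q // 4 = 5, so these toy parameters can occasionally decode a bit incorrectly; in the run above the noise happens to be exactly 0.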
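This intermediate result gives us an array of values where 11, which we saw
during encryption, corresponds to the high bit 1. In general the noise can push
a value slightly away from 0 or 11 in either direction, and a small negative
error on a 0 bit wraps around to just below q. So we decode a value as 1 when it
falls in the middle band between q // 4 and 3 * q // 4, and as 0 otherwise,
which converts the values back into the original binary message.
decrypted_message = jnp.where((intermediate > q // 4) & (intermediate < 3 * q // 4), 1, 0)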
decrypted_message
# => Array([1, 1, 0, 0, 1, 1, 0, 0], dtype=int32)
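Putting the pieces together, the whole scheme can be sketched as three small functions. This is an illustrative variant under my own assumptions, not the exact code above: the function names keygen, encrypt, and decrypt are hypothetical, each message bit gets its own random mask (a row of R), every random draw uses its own PRNG key, and q is raised to 97 so that the worst-case noise (at most m = 8) always stays below q // 4 and decryption cannot fail.

```python
import jax.numpy as jnp
from jax import random

# Toy parameters. q = 97 is an assumption made here so that m < q // 4.
n, m, q = 4, 8, 97

def keygen(rng):
    """Return ((A, b), s) with b = A s + e mod q and e in {-1, 0, 1}^m."""
    s_rng, a_rng, e_rng = random.split(rng, 3)
    s = random.randint(s_rng, shape=(n,), minval=0, maxval=q)
    A = random.randint(a_rng, shape=(m, n), minval=0, maxval=q)
    e = random.randint(e_rng, shape=(m,), minval=-1, maxval=2)
    return (A, (A @ s + e) % q), s

def encrypt(rng, public_key, bits):
    """Encrypt each bit with its own binary mask (one row of R per bit)."""
    A, b = public_key
    R = random.bernoulli(rng, shape=(bits.shape[0], m)).astype(jnp.int32)
    c1 = (R @ A) % q                    # shape (num_bits, n)
    c2 = (R @ b + bits * (q // 2)) % q  # shape (num_bits,)
    return c1, c2

def decrypt(s, ciphertext):
    """Strip the mask with s, then decode values in the middle band as 1."""
    c1, c2 = ciphertext
    v = (c2 - c1 @ s) % q
    return jnp.where((v > q // 4) & (v < 3 * q // 4), 1, 0)

key_rng, enc_rng = random.split(random.PRNGKey(0))
public_key, s = keygen(key_rng)
message = jnp.array([1, 1, 0, 0, 1, 1, 0, 0])
recovered = decrypt(s, encrypt(enc_rng, public_key, message))
print(recovered)  # [1 1 0 0 1 1 0 0]
```

Because every bit now has its own mask, c1 becomes a matrix with one row per bit, and two equal plaintext bits no longer have to share the same c2 value. The same warning applies as before: these parameters are far too small to be secure.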