Doing Post-Quantum Cryptography in JAX

In 2018, while studying ML and cryptography at around the same time, I realized that many cryptographic algorithms can be expressed as computation graphs, the same kind that major ML frameworks are built around. This led to a completely frivolous attempt to implement cryptographic algorithms in TensorFlow, just to see if it would work.

The world has changed a lot since 2018. JAX is growing in popularity as an ML framework, and in the cryptography space, the threat that quantum computers pose to today's public-key algorithms has gone from mostly theoretical to slightly more real. So, continuing the motivation of the original post, I decided to try implementing a post-quantum cryptography algorithm in JAX.

As a heads up, and it should go without saying: please do not use any of this code for real cryptography.

Lattice-based cryptography is one of the leading areas of post-quantum cryptography. Algorithms like Diffie-Hellman and RSA rely on the hardness of computing discrete logarithms or factoring large integers, and Shor's algorithm would let a sufficiently large quantum computer solve both efficiently. Lattice-based problems, by contrast, are currently thought to be resistant to quantum attacks because there is no known quantum algorithm that solves them efficiently.

Learning with Errors (LWE) is one of the most popular lattice-based problems to build schemes on. The following code demonstrates a Diffie-Hellman-style public key encryption scheme based on LWE, written in JAX.

import jax.numpy as jnp
from jax import random

# Global parameters

# Dimension of the secret key `s`
n = 4

# Modulus: the math is done modulo this number (a prime), which puts the computation in a finite field
q = 23

# Number of bits in each message
m = 8

The first step is to generate a keypair. The private key s is a vector of length n = 4 (chosen above), with entries drawn uniformly from 0 to q - 1.

# Generate Keypair
key_rng = random.PRNGKey(42)
secret_rng, pub_rng = random.split(key_rng)

# Secret Key
s = random.randint(secret_rng, shape=(n,), minval=0, maxval=q)
s
# => Array([ 1, 13,  3, 22], dtype=int32)

The public key is a matrix A and vector b.

# The `A` random matrix is the first component of the public key.
A = random.randint(pub_rng, shape=(m, n), minval=0, maxval=q)

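# `e` is the error vector, used to generate `b` then thrown out.
# Its entries are drawn uniformly from {-1, 0, 1} (`maxval` is exclusive).
# Note: this reuses `pub_rng`, which was already used for `A`. Idiomatic JAX
# would split it first (e.g. `a_rng, e_rng = random.split(pub_rng)`) so that
# `A` and `e` are independent.
e = random.randint(pub_rng, shape=(m,), minval=-1, maxval=2)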

print(A.shape) # => (8, 4)
print(e.shape) # => (8,)
# Generate the second component of the public key, vector `b`:
# the public matrix `A` times the secret key `s`, plus the error vector,
# reduced modulo `q` to put it within the finite field.
b = (A @ s + e) % q
b
# => Array([ 2, 18,  6,  1, 20,  9, 13, 10], dtype=int32)
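As a sanity check on key generation, here's a sketch that confirms `b` really is `A @ s` plus a small error: reducing `b - A @ s` modulo `q` gives back `e`. Recovering `s` from `(A, b)` alone, without knowing `e`, is the LWE problem the scheme relies on. One change from the code above, an assumption on my part rather than the original approach: the hypothetical helper `make_keys` splits the PRNG key three ways, so `A` and `e` come from independent keys instead of both using `pub_rng`. The concrete values will therefore differ from the outputs shown above.

```python
import jax.numpy as jnp
from jax import random

n, m, q = 4, 8, 23

def make_keys(rng):
    # Hypothetical helper: one independent PRNG key each for s, A and e.
    s_rng, a_rng, e_rng = random.split(rng, 3)
    s = random.randint(s_rng, shape=(n,), minval=0, maxval=q)
    A = random.randint(a_rng, shape=(m, n), minval=0, maxval=q)
    e = random.randint(e_rng, shape=(m,), minval=-1, maxval=2)
    b = (A @ s + e) % q
    return s, (A, b), e

sk, (A_pub, b_pub), err = make_keys(random.PRNGKey(42))

# `%` on JAX arrays follows Python's convention: the result takes the sign
# of q, so a negative error such as -1 lands at q - 1 = 22.
print(jnp.array([-1, 24]) % q)  # [22  1]

# Subtracting A @ s from b leaves only the error, reduced mod q.
residual = (b_pub - A_pub @ sk) % q
# Mapping values above q // 2 back to negatives recovers e exactly.
centered = jnp.where(residual > q // 2, residual - q, residual)
print(residual, centered)
```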

We now have the public key (A, b) and the private key s, which we can use to encrypt and decrypt messages in an asymmetric encryption scheme.

Encryption uses the public key (A, b) to encrypt a binary message vector. A random binary mask r introduces randomness into the process (like e above) so that, for security, the same message does not always map to the same ciphertext. The message is encoded into the finite field modulo q: the plaintext bit 0 is encoded as 0, and the plaintext bit 1 is encoded as q // 2, the middle of the field (rounded down, since q is odd). Because q is 23, 0 becomes 0 and 1 becomes 11 (23 // 2 = 11). These encoded values are hidden by the mask in the ciphertext and only reappear, approximately, after decryption.
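A minimal sketch of the encoding step on its own, with hypothetical helper names `encode` and `decode`: bits are scaled by q // 2, and decoding picks whichever of 0 or q // 2 a value is (roughly) nearest to, treating values just below q as close to 0 because arithmetic wraps around modulo q. The two-sided threshold in `decode` is my assumption of a reasonable rounding rule. Adding a little noise shows why the encoding tolerates the error term.

```python
import jax.numpy as jnp

q = 23

def encode(bits):
    # Bit 0 -> 0, bit 1 -> q // 2 (= 11 for q = 23), the middle of the field.
    return bits * (q // 2)

def decode(values):
    # 1 if the value lies strictly between q // 4 and 3 * q // 4, i.e. it is
    # (roughly) nearer q // 2 than 0 or q, which are the same point mod q.
    values = values % q
    return jnp.where((values > q // 4) & (values < 3 * q // 4), 1, 0)

bits = jnp.array([1, 1, 0, 0, 1, 1, 0, 0])
print(encode(bits))  # [11 11  0  0 11 11  0  0]

# Small noise shifts the values but they still decode correctly, including
# 0 bits that wrapped around to q - 2 = 21 and q - 1 = 22.
noise = jnp.array([2, -2, -2, 1, 0, 2, 2, -1])
print(decode(encode(bits) + noise))  # [1 1 0 0 1 1 0 0]
```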

Encryption produces a ciphertext (c1, c2), where c1 is a vector with the same length as the secret key (n = 4) and c2 is a vector with the same length as the message (m = 8).

message = jnp.array([1, 1, 0, 0, 1, 1, 0, 0])
message
# => Array([1, 1, 0, 0, 1, 1, 0, 0], dtype=int32)
enc_rng = random.PRNGKey(24)

# Generate a random binary vector as the mask
r = random.bernoulli(enc_rng, shape=(m,))
r = r.astype(jnp.int32)

# Encryption: c1 = (r @ A) % q, c2 = (r @ b + message * (q // 2)) % q
c1 = jnp.dot(r, A) % q
c2 = (jnp.dot(r, b) + message * (q // 2)) % q
print('Ciphertext c1:', c1)
print('Ciphertext c2:', c2)
# => Ciphertext c1: [ 2 15  3 12]
# => Ciphertext c2: [21 21 10 10 21 21 10 10]
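# Note: in a real system this would be more random. Every bit here shares the
# same mask `r`, so c2 takes only two values and reveals which message bits are equal.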
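The encryption above shares one mask `r` across all eight bits, which is why `c2` only takes two distinct values. A common alternative, and closer to how Regev's original LWE scheme encrypts one bit at a time, is to draw a fresh mask for every bit. The sketch below does that and also wraps the steps in `jax.jit`, so each function is traced and compiled by XLA into a computation graph on first call. It is my own restructuring, not the original post's code: the function names `keygen`, `encrypt` and `decrypt` are hypothetical, and the decoder uses a two-sided threshold.

```python
import jax
import jax.numpy as jnp
from jax import random

n, m, q = 4, 8, 23

def keygen(rng):
    # Hypothetical helper: same distributions as above, one PRNG key per draw.
    s_rng, a_rng, e_rng = random.split(rng, 3)
    s = random.randint(s_rng, shape=(n,), minval=0, maxval=q)
    A = random.randint(a_rng, shape=(m, n), minval=0, maxval=q)
    e = random.randint(e_rng, shape=(m,), minval=-1, maxval=2)
    return s, (A, (A @ s + e) % q)

@jax.jit
def encrypt(public_key, bits, rng):
    A, b = public_key
    # One fresh binary mask per message bit, so R has shape (len(bits), m).
    R = random.bernoulli(rng, shape=(bits.shape[0], m)).astype(jnp.int32)
    c1 = (R @ A) % q                     # shape (len(bits), n): one row per bit
    c2 = (R @ b + bits * (q // 2)) % q   # shape (len(bits),)
    return c1, c2

@jax.jit
def decrypt(s, c1, c2):
    x = (c2 - c1 @ s) % q
    # Two-sided threshold: 1 if x is nearer q // 2 than 0 (or q).
    return jnp.where((x > q // 4) & (x < 3 * q // 4), 1, 0)

secret_key, public_key = keygen(random.PRNGKey(0))
msg_bits = jnp.array([1, 1, 0, 0, 1, 1, 0, 0])
ct1, ct2 = encrypt(public_key, msg_bits, random.PRNGKey(1))
print(ct1.shape, ct2.shape)  # (8, 4) (8,)
print(decrypt(secret_key, ct1, ct2))
```

Because each bit now has its own mask, equal message bits no longer produce equal `c2` values, at the cost of one `c1` row per bit. Decryption can still occasionally fail when the masked error is large, so the last line prints the original message with high probability rather than with certainty.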

Decryption combines the secret key s with the ciphertext (c1, c2):

intermediate = (c2 - c1 @ s) % q
intermediate
# => Array([11, 11,  0,  0, 11, 11,  0,  0], dtype=int32)
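To see why decryption recovers the message, substitute b = (A s + e) mod q and the definitions of c1 and c2 from the encryption step. Here r is the binary mask, e the error vector, msg_j the j-th message bit, and all arithmetic is mod q:

$$
\begin{aligned}
c_{2,j} - c_1 \cdot s &\equiv r \cdot b + \mathrm{msg}_j \lfloor q/2 \rfloor - (rA) \cdot s \\
&\equiv r \cdot (As + e) + \mathrm{msg}_j \lfloor q/2 \rfloor - r \cdot (As) \\
&\equiv r \cdot e + \mathrm{msg}_j \lfloor q/2 \rfloor \pmod{q}
\end{aligned}
$$

The secret key cancels the masked public matrix, leaving only the encoded bit plus the masked error r · e. In the run above the intermediate values are exactly 0 and 11, so r · e happened to be 0 for this particular mask and error; in general it shifts the values by a few units, which is why decryption rounds rather than testing for equality.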

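The intermediate result is an array in which 11, the value the bit 1 was encoded to during encryption, marks a 1 and 0 marks a 0. In general the error term nudges these values by a few units in either direction, and a 0 bit can wrap around to just below q (for example, -1 becomes 22). So rather than testing for equality, we decode a value as 1 when it is closer to q // 2 than to 0 (or q), that is, when it lies strictly between q // 4 and 3 * q // 4, and as 0 otherwise. This converts the values back into the original binary message.

decrypted_message = jnp.where((intermediate > q // 4) & (intermediate < 3 * q // 4), 1, 0)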
decrypted_message
# => Array([1, 1, 0, 0, 1, 1, 0, 0], dtype=int32)
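How safe is the threshold with these toy parameters? Since r is binary and every e_i is in {-1, 0, 1}, the masked error r · e can in the worst case reach ±m = ±8, beyond the roughly ±5 units of slack the threshold allows, so decryption is not guaranteed to succeed. The sketch below computes the exact distribution of r · e, assuming each r_i is an independent fair coin and each e_i is uniform on {-1, 0, 1} (the distributions the code above samples from). This averages over random keys and masks rather than describing one fixed key, and it decodes with the two-sided threshold q // 4 < x < 3 * q // 4.

```python
import jax.numpy as jnp

m, q = 8, 23

# Worst case: r is binary and |e_i| <= 1, so |r . e| can be as large as m,
# while the threshold only leaves about q // 4 units of slack.
print(m, q // 4)  # 8 5

# Each term r_i * e_i is -1, 0 or +1 with probability 1/6, 2/3, 1/6
# (assuming r_i is a fair coin and e_i is uniform on {-1, 0, 1}).
term = jnp.array([1 / 6, 2 / 3, 1 / 6])

# The sum of m independent terms has the m-fold convolution as its
# distribution: dist[k] is the probability that r . e == k - m.
dist = jnp.array([1.0])
for _ in range(m):
    dist = jnp.convolve(dist, term)
shifts = jnp.arange(-m, m + 1)

def decode(x):
    x = x % q
    return jnp.where((x > q // 4) & (x < 3 * q // 4), 1, 0)

# Probability that the masked error pushes an encoded 0 or 1 across the threshold.
fail_0 = jnp.where(decode(shifts) != 0, dist, 0.0).sum()
fail_1 = jnp.where(decode(shifts + q // 2) != 1, dist, 0.0).sum()
print(float(fail_0), float(fail_1))
```

Both failure probabilities come out well under 0.1% per bit, which is why the example decrypts correctly in practice even though the worst case would break it. Real parameter sets choose q, m and the error distribution so that this probability is negligible.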