Doing Cryptography in TensorFlow

TensorFlow, a popular machine learning framework, is more generally a platform for doing computation over tensors, structured as a graph of operations.

Cryptographic algorithms are frequently structured as the manipulation of vectors and matrices of bytes in the structure of a graph.

You might begin to see where this is going.

What follows is a completely frivolous experiment to implement various cryptographic algorithms in TensorFlow.

Important note: dear god please don’t use this code for real cryptography!

Figure: A Feistel Network, the algorithm behind DES.

Setup

First, import TensorFlow and define a couple of utility functions we need.

import tensorflow as tf

def int_list_to_hex(l):
  return ''.join("{0:0{1}x}".format(x, 2) for x in l)

def int_list_to_string(l):
  return ''.join(chr(x) for x in l)

The One-Time Pad

The One-Time Pad (OTP) is the simplest cipher there is and, used correctly, the most secure. So why doesn’t everything use it?

For a key $k$ and a message $m$, OTP is defined as $c = k \oplus m$: you XOR ($\oplus$) every byte of the message with the corresponding byte of the key. Decryption is the same operation, since $k \oplus c = m$. This means the key has to be as long as the message, can never be reused, and has to reach the recipient securely. And if you have a secure way of transmitting that much key material, why not use it to send the message itself? OTP is still used for extremely security-critical things, like the president’s red phone.
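The definition above can be checked in a few lines of plain Python 3, independent of TensorFlow. This is only a sketch: the helper name `xor_bytes` is made up for illustration, and `os.urandom` stands in for whatever key source you would actually use.

```python
import os

def xor_bytes(data, key):
    """XOR two equal-length byte strings (hypothetical helper)."""
    if len(data) != len(key):
        raise ValueError("the key must be exactly as long as the message")
    return bytes(a ^ b for a, b in zip(data, key))

message = b"Hello this is a secret message."
key = os.urandom(len(message))          # one random key byte per message byte

ciphertext = xor_bytes(message, key)    # c = k XOR m
recovered = xor_bytes(ciphertext, key)  # k XOR c = k XOR k XOR m = m

print(recovered == message)  # True
```

Encryption and decryption are the same function here, which is exactly what the XOR definition implies.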

In TensorFlow we can use the built-in op tf.bitwise.bitwise_xor to XOR vectors of arbitrary length, giving us a One-Time Pad.

message_str = "Hello this is a secret message."
message = tf.constant([ord(c) for c in message_str], tf.uint8)

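# tf.random_uniform cannot produce uint8 directly, so draw int32 values in
# [0, 256) and cast. Storing them in a Variable fixes the key at initialization;
# a bare random op would draw a new key on every eval.
key_int32 = tf.Variable(tf.random_uniform(message.shape, minval=0, maxval=2**8, dtype=tf.int32))
key = tf.cast(key_int32, tf.uint8)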

encrypt_xor = tf.bitwise.bitwise_xor(message, key)
decrypt_xor = tf.bitwise.bitwise_xor(encrypt_xor, key)

with tf.Session().as_default() as session:
  session.run(tf.global_variables_initializer())
  print 'key:       ', int_list_to_hex(key.eval())
  print 'message:   ', int_list_to_string(message.eval())

  ciphertext = encrypt_xor.eval()
  print 'ciphertext:', int_list_to_hex(ciphertext)

  plaintext = decrypt_xor.eval()
  print 'plaintext: ', int_list_to_string(plaintext)

Which outputs:

key:        283ad1b54c88fb82d6e240c2c42651d11e4a2691d94195361e4fa5dd1e2225
message:    Hello this is a secret message.
ciphertext: 605fbdd923a88feabf9160abb70630f16d2f45e3bc35b55b7b3cd6bc79470b
plaintext:  Hello this is a secret message.

The Data Encryption Standard

The Data Encryption Standard, or DES, was the crypto workhorse from the 1970s through the 1990s. The core of the algorithm is a Feistel Network, a construction that lets you build an invertible function (i.e. one that can encrypt a message and then decrypt it back to the same message) out of a non-invertible pseudo-random function (roughly, a keyed function whose outputs look as if they had been assigned to each input at random).

The Feistel Network works by splitting the input into two halves (a left half and a right half) and feeding those halves through a series of rounds (16 in DES), as illustrated in the figure.

Given a pseudo-random function $F$, the next round of the encryption algorithm (left half: $L_{i+1}$, right half: $R_{i+1}$) is computed as:

$$L_{i+1} = R_i$$

$$R_{i+1} = L_i \oplus F(R_i,K_i)$$

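Decryption reverses each round. Since $L_{i+1} = R_i$, the value $F(R_i,K_i)$ can simply be recomputed and XORed away, so $F$ itself never has to be inverted: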

$$L_i = R_{i+1} \oplus F(L_{i+1},K_i)$$

$$R_i = L_{i+1}$$
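To make the two sets of equations concrete, here is a small plain-Python 3 sketch, separate from the TensorFlow code in this post. Everything in it is invented for illustration: the round function `toy_f` (a truncated SHA-256 hash, which is not invertible and not a vetted PRF), the helper names, and the round keys.

```python
import hashlib

def toy_f(half, round_key):
    """Hypothetical round function F: hash the round key together with the half.

    Truncating the digest throws information away, so this is not invertible.
    """
    digest = hashlib.sha256(bytes([round_key]) + half).digest()
    return digest[:len(half)]

def xor_bytes(a, b):
    return bytes(x ^ y for x, y in zip(a, b))

def encrypt_round(left, right, round_key):
    # L_{i+1} = R_i ;  R_{i+1} = L_i XOR F(R_i, K_i)
    return right, xor_bytes(left, toy_f(right, round_key))

def decrypt_round(left_next, right_next, round_key):
    # L_i = R_{i+1} XOR F(L_{i+1}, K_i) ;  R_i = L_{i+1}
    return xor_bytes(right_next, toy_f(left_next, round_key)), left_next

round_keys = [3, 1, 4]          # made-up round keys, one per round
left, right = b"abac", b"adae"  # the two halves of the 8-byte block "abacadae"

for k in round_keys:
    left, right = encrypt_round(left, right, k)
ciphertext = left + right

for k in reversed(round_keys):  # undo the rounds in the opposite order
    left, right = decrypt_round(left, right, k)

print(left + right)  # b'abacadae'
```

Note that `decrypt_round` never inverts `toy_f`; it only recomputes it. That is the property that lets a Feistel network be built from a non-invertible round function.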

BLOCK_SIZE = 128

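def feistel_network_encrypt_round(round_key, left_0, right_0):
  """Run one encryption round of a Feistel network.

  Args:
    round_key: The round key K_i. Unused here, because the stand-in for F
      below ignores the key.
    left_0: The left half L_i, a uint8 tensor.
    right_0: The right half R_i, a uint8 tensor of the same shape.
  Returns:
    A tuple (R_i, L_i XOR F(R_i)), which is (L_{i+1}, R_{i+1}).
  """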
  # (Using bitwise inversion instead of a true PRF)
  f_ri_ki = tf.bitwise.invert(right_0)
  right_plusone = tf.bitwise.bitwise_xor(left_0, f_ri_ki)

  return right_0, right_plusone


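def feistel_network_decrypt_round(round_key, left_plusone, right_plusone):
  """Run one decryption round of a Feistel network.

  Args:
    round_key: The round key K_i. Unused here, because the stand-in for F
      below ignores the key.
    left_plusone: The first half of the round's output, a uint8 tensor.
    right_plusone: The second half of the round's output, a uint8 tensor of
      the same shape.
  Returns:
    A tuple (right_plusone XOR F(left_plusone), right_plusone).
  """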
  # (Using bitwise inversion instead of a true PRF)
  f_lip1_ki = tf.bitwise.invert(left_plusone)
  right_0 = tf.bitwise.bitwise_xor(right_plusone, f_lip1_ki)

  return right_0, right_plusone


message_str = "abacadae"
input_tensor = tf.constant([ord(c) for c in message_str], tf.uint8)


with tf.Session().as_default() as session:
  # Placeholder round keys: the stand-in for F ignores its key, so these
  # values do not affect the output.
  key = [42, 42, 42]

  left_0, right_0 = tf.split(input_tensor, num_or_size_splits=2)

  print 'plaintext left  ', int_list_to_hex(left_0.eval())
  print 'plaintext right ', int_list_to_hex(right_0.eval())

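  # Each call returns (L_{i+1}, R_{i+1}). Unpacking that pair as (right, left)
  # and passing it back in as (left, right) cancels the swap between rounds,
  # so the right half is carried through unchanged.
  right_1, left_1 = feistel_network_encrypt_round(key[0], left_0, right_0)
  right_2, left_2 = feistel_network_encrypt_round(key[1], left_1, right_1)
  right_3, left_3 = feistel_network_encrypt_round(key[2], left_2, right_2)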

  print 'ciphertext left ', int_list_to_hex(left_3.eval())
  print 'ciphertext right', int_list_to_hex(right_3.eval())
  
  # Undo the rounds in reverse order, so the round keys go from last to first.
  left_2p, right_2p = feistel_network_decrypt_round(key[2], left_3, right_3)
  left_1p, right_1p = feistel_network_decrypt_round(key[1], left_2p, right_2p)
  left_0p, right_0p = feistel_network_decrypt_round(key[0], left_1p, right_1p)

  print 'plaintext left  ', int_list_to_hex(left_0p.eval())
  print 'plaintext right ', int_list_to_hex(right_0p.eval())
  print 'plaintext left  ', int_list_to_string(left_0p.eval())
  print 'plaintext right ', int_list_to_string(right_0p.eval())

Which outputs:

plaintext left   61626163
plaintext right  61646165
ciphertext left  fff9fff9
ciphertext right 61646165
plaintext left   61626163
plaintext right  61646165
plaintext left   abac
plaintext right  adae

And that’s it! I don’t think there are any practical use cases for writing cryptographic algorithms in TensorFlow, except perhaps encrypting an extremely large input by distributing the computation across many nodes.