Doing Cryptography in TensorFlow

[Figure: on the left, the Feistel network from the DES cipher, implemented below; on the right, a deep neural network.]

TensorFlow is a popular machine learning framework. Under the hood, though, TensorFlow is a general platform for computation over tensors, structured as a graph.

While studying cryptography, a completely different field of computer science, you might begin to notice that cryptographic algorithms are also frequently structured as graphs of operations on vectors and matrices of bytes. You can probably see where this is going.

What follows is a completely frivolous experiment to implement various cryptographic algorithms in TensorFlow.

Important note: do not use this code for real cryptography!

# Imports and helper functions
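# This code uses the TensorFlow 1.x graph/session API (tf.Session,
# tf.random_uniform, ...). Importing it through tf.compat.v1 and disabling
# TF2 behavior lets the same code run on TensorFlow 2.x as well.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()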

def int_list_to_hex(l):
    return ''.join("{0:0{1}x}".format(x, 2) for x in l)

def int_list_to_string(l):
    return ''.join(chr(x) for x in l)

The One Time Pad

The One-Time Pad is one of the simplest ciphers, and, used correctly, it is also provably secure.

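For a key $k$ and a message $m$ of the same length, the OTP is defined as $E(k, m) = k \oplus m$, and decryption is the same operation, $D(k, c) = k \oplus c$. You XOR ($\oplus$) each byte of the key with the corresponding byte of the message. If the key is drawn uniformly at random and independently of the message, this gives perfect secrecy: XOR’ing a uniformly random value with any fixed value yields a uniformly random value, so the ciphertext carries no information about the message.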
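To make the perfect-secrecy argument concrete, here is a small pure-Python check (standard library only, separate from the TensorFlow code) for a one-byte message. Across all 256 possible keys, each message byte maps to every ciphertext byte exactly once, so two very different messages produce exactly the same ciphertext distribution. The helper names are illustrative.

```python
from collections import Counter


def otp_xor(key: bytes, data: bytes) -> bytes:
    """Encrypt or decrypt with a one-time pad: XOR each byte with the matching key byte."""
    if len(key) != len(data):
        raise ValueError("one-time pad key must be exactly as long as the data")
    return bytes(k ^ d for k, d in zip(key, data))


def ciphertext_counts(message_byte: int) -> Counter:
    """Count how often each ciphertext byte occurs over all 256 one-byte keys."""
    return Counter(otp_xor(bytes([k]), bytes([message_byte]))[0] for k in range(256))


# For two different messages, every one of the 256 ciphertext values
# appears exactly once, so the distributions are identical.
counts_a = ciphertext_counts(ord("A"))
counts_z = ciphertext_counts(ord("z"))
print(len(counts_a), set(counts_a.values()), counts_a == counts_z)  # 256 {1} True

# Decryption is the same XOR with the same key.
key = bytes([0x3C, 0xA5])
assert otp_xor(key, otp_xor(key, b"hi")) == b"hi"
```

Since $k \mapsto k \oplus m$ is a bijection on bytes for any fixed $m$, the same holds for every message byte, not just the two tested here.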

So why doesn’t everything use the One Time Pad if it’s the most secure cipher?

The big downside of the OTP is that the key must be as long as the message, and each key can be used strictly once (hence the “One-Time” Pad). That raises the problem of key distribution: if you already have a secure way of transmitting keys, why not use that channel to send the message itself? For this reason the One-Time Pad is rarely used, but it has been used for extremely security-critical purposes, such as the Washington-Moscow hotline popularly known as the president’s “red phone”.
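Why strictly once? A quick standard-library sketch (the two messages are made up for illustration) shows what goes wrong when a key is reused: XOR’ing two ciphertexts made with the same key cancels the key and leaves the XOR of the two plaintexts.

```python
import secrets


def otp_xor(key: bytes, data: bytes) -> bytes:
    """XOR each data byte with the corresponding key byte."""
    return bytes(k ^ d for k, d in zip(key, data))


m1 = b"attack at dawn"
m2 = b"attack at dusk"
key = secrets.token_bytes(len(m1))  # one key, (wrongly) used for two messages

c1 = otp_xor(key, m1)
c2 = otp_xor(key, m2)

# The key cancels out: c1 XOR c2 == m1 XOR m2, whatever the key was.
leak = otp_xor(c1, c2)
assert leak == otp_xor(m1, m2)

# Wherever the two plaintexts agree, the leaked XOR is zero, so an attacker
# learns exactly which positions of the two messages are identical.
print([i for i, b in enumerate(leak) if b == 0])  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```

With real traffic, the XOR of two natural-language plaintexts usually contains enough structure to recover both messages.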

In TensorFlow, we can use the built-in op tf.bitwise.bitwise_xor to XOR two vectors of arbitrary (but equal) length, which gives us a one-time pad.

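message_str = "Hello this is a secret message."
message = tf.constant([ord(c) for c in message_str], tf.uint8)

# tf.random_uniform cannot sample uint8 directly, so draw int32 values in
# [0, 256) and cast. Keeping the key in a Variable fixes it across session
# runs; a bare random op would draw a fresh key every time it is evaluated.
# (tf.random_uniform is not a cryptographically secure random source.)
key_int32 = tf.Variable(tf.random_uniform(message.shape, minval=0, maxval=2**8, dtype=tf.int32))
key = tf.cast(key_int32, tf.uint8)

# Encryption and decryption are the same operation: XOR with the key.
encrypt_xor = tf.bitwise.bitwise_xor(message, key)
decrypt_xor = tf.bitwise.bitwise_xor(encrypt_xor, key)

# Using the Session as a context manager makes it the default session
# (so .eval() works) and closes it on exit.
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    print('key:'.ljust(24), int_list_to_hex(key.eval()))
    print('message:'.ljust(24), int_list_to_string(message.eval()))

    ciphertext = encrypt_xor.eval()
    print('encrypted ciphertext:'.ljust(24), int_list_to_hex(ciphertext))

    plaintext = decrypt_xor.eval()
    print('decrypted plaintext:'.ljust(24), int_list_to_string(plaintext))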

DES

The Data Encryption Standard, or DES, was the crypto workhorse from the 1970s through the 1990s. The core of the algorithm is a Feistel network, a construction that builds an invertible function (i.e. one whose output can always be decrypted back to the original message) out of a pseudo-random function that need not be invertible (for instance, a keyed hash function whose output is the same size as its input).
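The key claim is that the Feistel structure is invertible even when the round function is not. In the minimal standard-library sketch below, the hypothetical round function `f` keeps only the low 4 bits of its input, so it throws information away and has no inverse. One Feistel round can still be undone exactly, because decryption only ever recomputes `f`; it never needs to invert it.

```python
from typing import Tuple


def f(half: int) -> int:
    """A deliberately non-invertible round function: keep only the low 4 bits."""
    return half & 0x0F


def feistel_round(left: int, right: int) -> Tuple[int, int]:
    """One Feistel round on a pair of bytes: (L, R) -> (R, L XOR f(R))."""
    return right, left ^ f(right)


def feistel_round_inverse(left: int, right: int) -> Tuple[int, int]:
    """Undo one round: (L', R') -> (R' XOR f(L'), L'). Only recomputes f."""
    return right ^ f(left), left


# f maps 256 inputs onto just 16 outputs, so it has no inverse ...
assert len({f(x) for x in range(256)}) == 16

# ... yet every pair of bytes (L, R) survives a round trip.
for left_byte in range(256):
    for right_byte in range(256):
        pair = (left_byte, right_byte)
        assert feistel_round_inverse(*feistel_round(*pair)) == pair
print("all 65536 byte pairs recovered")
```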

The Feistel network works by splitting the input into two halves (a left half and a right half) and feeding those halves through a series of rounds (16 in DES), as illustrated in the accompanying figure.

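Given a pseudo-random function $F$ and a round key $K_i$, the next round of the encryption algorithm (left half: $L_i$, right half: $R_i$) is computed as:

$$L_{i+1} = R_i, \qquad R_{i+1} = L_i \oplus F(R_i, K_i)$$

To decrypt, run the algorithm in reverse, applying the round keys in reverse order:

$$R_i = L_{i+1}, \qquad L_i = R_{i+1} \oplus F(L_{i+1}, K_i)$$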
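For an end-to-end view of the Feistel rounds, including the reversed key order on decryption, here is a standard-library-only Python sketch, separate from the TensorFlow code. It assumes three illustrative choices: HMAC-SHA256 truncated to the half-block length as the keyed round function $F$, a 32-byte block, and 16 independent random 16-byte round keys. None of these is DES’s actual round function or key schedule.

```python
import hashlib
import hmac
import secrets
from typing import List

NUM_ROUNDS = 16


def round_function(half: bytes, round_key: bytes) -> bytes:
    """Keyed PRF F(R_i, K_i): HMAC-SHA256, truncated to the half-block length."""
    return hmac.new(round_key, half, hashlib.sha256).digest()[: len(half)]


def xor_bytes(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))


def split_halves(block: bytes):
    if len(block) % 2:
        raise ValueError("block length must be even")
    half = len(block) // 2
    return block[:half], block[half:]


def feistel_encrypt(block: bytes, round_keys: List[bytes]) -> bytes:
    left, right = split_halves(block)
    for k in round_keys:
        # L_{i+1} = R_i,  R_{i+1} = L_i XOR F(R_i, K_i)
        left, right = right, xor_bytes(left, round_function(right, k))
    return left + right


def feistel_decrypt(block: bytes, round_keys: List[bytes]) -> bytes:
    left, right = split_halves(block)
    for k in reversed(round_keys):
        # R_i = L_{i+1},  L_i = R_{i+1} XOR F(L_{i+1}, K_i)
        left, right = xor_bytes(right, round_function(left, k)), left
    return left + right


round_keys = [secrets.token_bytes(16) for _ in range(NUM_ROUNDS)]
block = b"Hello this is a secret message!!"  # 32 bytes: two 16-byte halves
ciphertext = feistel_encrypt(block, round_keys)
assert feistel_decrypt(ciphertext, round_keys) == block
print(ciphertext.hex())
```

Because decryption only recomputes $F$, swapping in a different round function leaves the inversion logic unchanged; only the order of the round keys matters.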

BLOCK_SIZE = 32
NUM_ROUNDS = 16

def round_function(half, round_key):
    """Toy stand-in for the keyed round function F(R_i, K_i).

    XORs the half-block with the (scalar) round key, then inverts every bit.
    This is NOT a pseudo-random function; it only keeps the demo simple.
    """
    return tf.bitwise.invert(tf.bitwise.bitwise_xor(half, round_key))


def feistel_network_encrypt_round(round_key, left_i, right_i):
    """Run one encryption round of a Feistel network.

    Args:
        round_key: the round key K_i used to key the round function.
        left_i: the left half of the input, L_i.
        right_i: the right half of the input, R_i.
    Returns:
        left_next: the left half of the output, L_{i+1} = R_i.
        right_next: the right half of the output, R_{i+1} = L_i XOR F(R_i, K_i).
    """
    left_next = right_i
    right_next = tf.bitwise.bitwise_xor(left_i, round_function(right_i, round_key))
    return left_next, right_next


def feistel_network_decrypt_round(round_key, left_next, right_next):
    """Undo one encryption round of a Feistel network.

    Args:
        round_key: the round key K_i that was used for this round.
        left_next: the left half of the round's output, L_{i+1}.
        right_next: the right half of the round's output, R_{i+1}.
    Returns:
        left_i: the recovered left half, L_i = R_{i+1} XOR F(L_{i+1}, K_i).
        right_i: the recovered right half, R_i = L_{i+1}.
    """
    right_i = left_next
    left_i = tf.bitwise.bitwise_xor(right_next, round_function(left_next, round_key))
    return left_i, right_i

def pkcs7_pad(text):
    # PKCS #7-style padding on a str: append `val` copies of chr(val), where
    # `val` (1..BLOCK_SIZE) is the number of characters needed to reach a
    # multiple of BLOCK_SIZE.
    val = BLOCK_SIZE - (len(text) % BLOCK_SIZE)
    return text + chr(val) * val

def pkcs7_unpad(text):
    # The last character records how many padding characters to strip.
    val = ord(text[-1])
    return text[:-val]

message_str = pkcs7_pad("Hello this is a secret message.")
input_tensor = tf.constant([ord(c) for c in message_str], tf.uint8)

# One random key byte per round, 16 bytes in total.
# (Note: this does not follow the DES key scheduling algorithm.)
key_int32 = tf.Variable(tf.random_uniform((NUM_ROUNDS,), minval=0, maxval=2**8, dtype=tf.int32))
key = tf.cast(key_int32, tf.uint8)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    print('key:'.ljust(24), int_list_to_hex(key.eval()))
    # repr() makes the non-printable padding characters visible.
    print('padded message:'.ljust(24), repr(int_list_to_string(input_tensor.eval())))

    # Encryption: split the input in half and run the network for 16 rounds,
    # using round key i in round i.
    left, right = tf.split(input_tensor, num_or_size_splits=2)

    for round_num in range(NUM_ROUNDS):
        left, right = feistel_network_encrypt_round(key[round_num], left, right)

    cipher_left, cipher_right = session.run([left, right])
    print('encrypted ciphertext:'.ljust(24), int_list_to_hex(cipher_left) + int_list_to_hex(cipher_right))

    # Decryption: undo the rounds in reverse order, so the round keys are
    # applied last-to-first.
    for round_num in reversed(range(NUM_ROUNDS)):
        left, right = feistel_network_decrypt_round(key[round_num], left, right)

    plain_left, plain_right = session.run([left, right])
    print('decrypted plaintext:'.ljust(24), pkcs7_unpad(int_list_to_string(plain_left) + int_list_to_string(plain_right)))
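The demo’s padding helpers work on Python strings. For comparison, here is a standard-library sketch of PKCS #7 padding over bytes, following the rule in RFC 5652, section 6.3: append $n$ bytes that each have the value $n$, where $1 \le n \le$ block size, so even block-aligned input gets a full block of padding. The function names and the choice to raise ValueError on bad padding are illustrative assumptions.

```python
def pkcs7_pad_bytes(data: bytes, block_size: int = 32) -> bytes:
    """Append n bytes of value n, with 1 <= n <= block_size."""
    if not 1 <= block_size <= 255:
        raise ValueError("PKCS #7 block size must be between 1 and 255 bytes")
    n = block_size - (len(data) % block_size)
    return data + bytes([n]) * n


def pkcs7_unpad_bytes(data: bytes, block_size: int = 32) -> bytes:
    """Strip PKCS #7 padding after checking that it is well-formed."""
    if not data or len(data) % block_size:
        raise ValueError("padded data must be a non-empty multiple of the block size")
    n = data[-1]
    if not 1 <= n <= block_size or data[-n:] != bytes([n]) * n:
        raise ValueError("invalid PKCS #7 padding")
    return data[:-n]


padded = pkcs7_pad_bytes(b"Hello this is a secret message.")
print(len(padded), padded[-1:])  # 32 b'\x01'
assert pkcs7_unpad_bytes(padded) == b"Hello this is a secret message."

# A message that is already block-aligned gets a whole extra block of padding.
assert len(pkcs7_pad_bytes(bytes(32))) == 64
```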

I don’t think there are any practical use cases for writing cryptographic algorithms in TensorFlow, except perhaps encrypting an extremely large input by distributing the computation across many nodes. Even then, it wouldn’t be smart to use TensorFlow for real cryptography. But I do think TensorFlow’s usability holds lessons that could transfer to the field of cryptography. TensorFlow’s graph structure isn’t the easiest API to wrap your head around, but once you grok it, it does provide a clear way to define computational graphs.
