Example: Save and Load a TensorFlow Model

This post shows how to save and load a TensorFlow model built with the tf.contrib.learn.DNNClassifier API.

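The key idea is to define, ahead of time, a function (or a class) that takes a model directory, passes it to RunConfig, and returns a tf.contrib.learn.Estimator such as tf.contrib.learn.DNNClassifier. The estimator saves its parameters to that directory during training and restores them from it later. Calling the same function again with the same directory after training therefore gives you an estimator that uses the trained model. See make_estimator below for details.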

import numpy as np
import tensorflow as tf

from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.training.python.training import hparam
# Directory in which the estimator saves its checkpoints during training and
# from which make_estimator() later restores them. Replace with a real path.
MODEL_DIR = 'your-model-dir'

# Hyperparameters for the training and evaluation input pipelines.
hparams = hparam.HParams(**{
    'num_epochs': 10,
    'train_batch_size': 50,
    'eval_batch_size': 50,
    'eval_steps': 10,
})
def make_estimator(model_dir):
    """Build a DNNClassifier that checkpoints into ``model_dir``.

    The directory is handed to the estimator through ``RunConfig``. The
    estimator uses it both to save parameters while training and to restore
    them afterwards. Calling this function again with the same directory
    therefore yields an estimator backed by the previously trained model.

    Args:
      model_dir: Path of the directory used to save and restore model
        checkpoints.

    Returns:
      A ``tf.contrib.learn.DNNClassifier`` with one numeric input feature
      named ``'random_feature'``, two output classes, and hidden layers of
      1024, 512 and 256 units.
    """
    feature_columns = [tf.feature_column.numeric_column(key='random_feature')]
    estimator_config = run_config.RunConfig(model_dir=model_dir)

    return tf.contrib.learn.DNNClassifier(
        hidden_units=[1024, 512, 256],
        feature_columns=feature_columns,
        n_classes=2,
        config=estimator_config,
    )
# Synthetic data: one random float feature per example.
# Labels for a 2-class classifier must be integer class ids in {0, 1}.
# np.random.rand would produce continuous values in [0, 1), which are not
# valid class labels, so draw them with randint instead.
dataset_size = 1000
X_train = np.random.rand(dataset_size)
y_train = np.random.randint(0, 2, size=dataset_size)
X_eval = np.random.rand(dataset_size)
y_eval = np.random.randint(0, 2, size=dataset_size)

# The estimator checkpoints into MODEL_DIR (configured in make_estimator),
# which is what lets a later make_estimator(MODEL_DIR) call reload the model.
estimator = make_estimator(MODEL_DIR)

experiment = tf.contrib.learn.Experiment(
    estimator,
    train_input_fn=tf.estimator.inputs.numpy_input_fn(
        x={'random_feature': X_train},
        y=y_train,
        # No train_steps is given, so training presumably runs until this
        # input is exhausted after num_epochs passes over the data.
        num_epochs=hparams.num_epochs,
        batch_size=hparams.train_batch_size,
        shuffle=True
    ),
    eval_input_fn=tf.estimator.inputs.numpy_input_fn(
        x={'random_feature': X_eval},
        y=y_eval,
        # Repeat indefinitely; the evaluation length is bounded by eval_steps.
        num_epochs=None,
        batch_size=hparams.eval_batch_size,
        shuffle=False  # Don't shuffle evaluation data
    ),
    # Use the configured value instead of Experiment's built-in default.
    eval_steps=hparams.eval_steps,
)
experiment.train()

Loading the model

# Inputs to score with the reloaded model.
X_predict = np.array([0.3, 0.4, 0.5])

# A fresh estimator pointed at the same MODEL_DIR. Its parameters come from
# the checkpoints that training saved in that directory.
estimator_from_file = make_estimator(MODEL_DIR)

predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'random_feature': X_predict},
    shuffle=False,  # keep the output order aligned with X_predict
    batch_size=hparams.eval_batch_size,
    num_epochs=1,  # a single pass over the prediction inputs
)

# Print each prediction returned by predict_proba.
predictions = estimator_from_file.predict_proba(input_fn=predict_input_fn)
for prediction in predictions:
    print('Prediction:', prediction)

Contents (top)

Comments