TensorFlow Examples

MNIST

The MNIST dataset has a training set of 60,000 examples and a test set of 10,000 examples of handwritten digits. Each example is a 28 x 28-pixel monochrome image.

This sample shows how to use the low-level TensorFlow APIs and tf.estimator.Estimator to build a simple convolutional neural network classifier, and how to use vai_p_tensorflow to prune it.

TensorFlow Low Level API

Download and Convert Dataset

Create a file called data_utils.py, and add the following code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip, os, sys
from six.moves import urllib

import numpy as np
import tensorflow as tf

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

# The URLs where the MNIST data can be downloaded.
_DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
_TRAIN_DATA_FILENAME = 'train-images-idx3-ubyte.gz'
_TRAIN_LABELS_FILENAME = 'train-labels-idx1-ubyte.gz'
_TEST_DATA_FILENAME = 't10k-images-idx3-ubyte.gz'
_TEST_LABELS_FILENAME = 't10k-labels-idx1-ubyte.gz'
_LABELS_FILENAME = 'labels.txt'
_DATASET_DIR = 'data/mnist'

_IMAGE_SIZE = 28
_NUM_CHANNELS = 1
_NUM_LABELS = 10

# The names of the classes.
_CLASS_NAMES = [
    'zero',
    'one',
    'two',
    'three',
    'four',
    'five',
    'size',
    'seven',
    'eight',
    'nine',
]

def _extract_images(filename, num_images):
  """Read `num_images` images from a gzipped MNIST image file.

  Args:
    filename: Path to a gzipped MNIST images file.
    num_images: How many images to read from the file.

  Returns:
    A uint8 numpy array with shape
    [num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS].
  """
  print('Extracting images from: ', filename)
  pixels_per_image = _IMAGE_SIZE * _IMAGE_SIZE * _NUM_CHANNELS
  with gzip.open(filename) as stream:
    # Skip the 16-byte file header; pixel bytes follow it.
    stream.read(16)
    raw = stream.read(pixels_per_image * num_images)
    images = np.frombuffer(raw, dtype=np.uint8).reshape(
        num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
  return images


def _extract_labels(filename, num_labels):
  """Extract the labels into a vector of int64 label IDs.

  Args:
    filename: The path to an MNIST labels file.
    num_labels: The number of labels in the file.

  Returns:
    A numpy array of shape [number_of_labels]
  """
  print('Extracting labels from: ', filename)
  with gzip.open(filename) as bytestream:
    bytestream.read(8)
    buf = bytestream.read(1 * num_labels)
    labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
  return labels

def int64_feature(values):
  """Wrap one integer, or a tuple/list of integers, in a tf.train.Feature.

  Args:
    values: A scalar, or a tuple or list of scalars.

  Returns:
    A tf.train.Feature whose int64_list holds `values`.
  """
  value_list = values if isinstance(values, (tuple, list)) else [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))

def bytes_feature(values):
  """Wrap a single bytes value in a tf.train.Feature.

  Args:
    values: One bytes string. Despite the plural name, it is stored as the
      only element of the list.

  Returns:
    A tf.train.Feature whose bytes_list holds `values`.
  """
  bytes_list = tf.train.BytesList(value=[values])
  return tf.train.Feature(bytes_list=bytes_list)


def _image_to_tfexample(image_data, class_id):
  """Build a tf.train.Example holding an encoded image and its class id.

  Args:
    image_data: Encoded image bytes (PNG in this module).
    class_id: Integer label of the image.

  Returns:
    A tf.train.Example with 'image/encoded' and 'image/class/label' features.
  """
  feature_map = {
      'image/encoded': bytes_feature(image_data),
      'image/class/label': int64_feature(class_id),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))

def _add_to_tfrecord(data_filename, labels_filename, num_images,
                     tfrecord_writer):
  """Loads data from the binary MNIST files and writes it to a TFRecord.

  Each image is PNG-encoded in a private graph/session and written as a
  tf.train.Example with 'image/encoded' and 'image/class/label' features.

  Args:
    data_filename: The filename of the MNIST images.
    labels_filename: The filename of the MNIST labels.
    num_images: The number of images in the dataset.
    tfrecord_writer: The TFRecord writer to use for writing.
  """
  images = _extract_images(data_filename, num_images)
  labels = _extract_labels(labels_filename, num_images)

  shape = (_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
  # A separate graph keeps the encoder ops out of the caller's default graph.
  with tf.Graph().as_default():
    image = tf.placeholder(dtype=tf.uint8, shape=shape)
    encoded_png = tf.image.encode_png(image)

    with tf.Session('') as sess:
      for j in range(num_images):
        sys.stdout.write('\r>> Converting image %d/%d' % (j + 1, num_images))
        sys.stdout.flush()

        png_string = sess.run(encoded_png, feed_dict={image: images[j]})
        example = _image_to_tfexample(png_string, labels[j])
        tfrecord_writer.write(example.SerializeToString())
  # End the '\r' progress line so that subsequent output starts on a new line
  # instead of overwriting it.
  sys.stdout.write('\n')
  sys.stdout.flush()


def _get_output_filename(dataset_dir, split_name):
  """Creates the output filename.

  Args:
    dataset_dir: The directory where the temporary files are stored.
    split_name: The name of the train/test split.

  Returns:
    An absolute file path.
  """
  return '%s/mnist_%s.tfrecord' % (dataset_dir, split_name)


def _download_dataset(dataset_dir):
  """Downloads the four MNIST archives into `dataset_dir` if they are missing.

  Files that already exist are left untouched. urlretrieve does not create
  directories, so `dataset_dir` must already exist (download_and_convert
  creates it first).

  Args:
    dataset_dir: The directory where the temporary files are stored.
  """
  def _progress(count, block_size, total_size):
    """urlretrieve report hook that prints download progress in place."""
    downloaded = count * block_size
    if total_size > 0:
      # The last block may overshoot the file size; clamp the display.
      percent = min(100.0, float(downloaded) / float(total_size) * 100.0)
      sys.stdout.write('\r>> Downloading %.1f%%' % percent)
    else:
      # The size is unknown (non-positive total_size), so avoid dividing by it.
      sys.stdout.write('\r>> Downloaded %d bytes' % downloaded)
    sys.stdout.flush()

  for filename in [_TRAIN_DATA_FILENAME,
                   _TRAIN_LABELS_FILENAME,
                   _TEST_DATA_FILENAME,
                   _TEST_LABELS_FILENAME]:
    filepath = os.path.join(dataset_dir, filename)
    if os.path.exists(filepath):
      continue

    print('Downloading file %s...' % filename)
    filepath, _ = urllib.request.urlretrieve(_DATA_URL + filename,
                                             filepath,
                                             _progress)
    print()
    with tf.gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')

def _write_label_file(labels_to_class_names, dataset_dir,
                     filename=_LABELS_FILENAME):
  """Write one '<label>:<class name>' line per entry of the mapping.

  Args:
    labels_to_class_names: Dict mapping integer labels to class names.
    dataset_dir: Directory in which the labels file is written.
    filename: Name of the labels file inside `dataset_dir`.
  """
  labels_path = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_path, 'w') as out:
    for label, class_name in labels_to_class_names.items():
      out.write('%d:%s\n' % (label, class_name))

def _clean_up_temporary_files(dataset_dir):
  """Delete the four downloaded MNIST .gz archives from `dataset_dir`.

  Args:
    dataset_dir: The directory where the temporary files are stored.
  """
  archives = (_TRAIN_DATA_FILENAME, _TRAIN_LABELS_FILENAME,
              _TEST_DATA_FILENAME, _TEST_LABELS_FILENAME)
  for archive in archives:
    tf.gfile.Remove(os.path.join(dataset_dir, archive))


def download_and_convert(dataset_dir, clean=False):
  """Runs the download and conversion operation.

  Writes mnist_train.tfrecord, mnist_test.tfrecord and the labels file into
  `dataset_dir`. Returns immediately if both TFRecord files already exist.

  Args:
    dataset_dir: The dataset directory where the dataset is stored. It is
      created if it does not exist.
    clean: If True, delete the downloaded .gz archives after conversion.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  training_filename = _get_output_filename(dataset_dir, 'train')
  testing_filename = _get_output_filename(dataset_dir, 'test')

  if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  _download_dataset(dataset_dir)

  # (images archive, labels archive, number of examples, output TFRecord),
  # processed training split first, then the test split.
  splits = [
      (_TRAIN_DATA_FILENAME, _TRAIN_LABELS_FILENAME, 60000, training_filename),
      (_TEST_DATA_FILENAME, _TEST_LABELS_FILENAME, 10000, testing_filename),
  ]
  for data_name, labels_name, num_images, output_filename in splits:
    with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
      _add_to_tfrecord(os.path.join(dataset_dir, data_name),
                       os.path.join(dataset_dir, labels_name),
                       num_images, tfrecord_writer)

  # Finally, write the labels file:
  _write_label_file(dict(enumerate(_CLASS_NAMES)), dataset_dir)

  if clean:
    _clean_up_temporary_files(dataset_dir)
  print('\nFinished converting the MNIST dataset!')

def _parse_function(tfrecord_serialized):
  """Parse one serialized tf.train.Example into a normalized image and label.

  Args:
    tfrecord_serialized: Scalar string tensor holding one serialized
      tf.train.Example, as yielded by a tf.data.TFRecordDataset.

  Returns:
    image: The decoded PNG divided by 255 (floating point, values in [0, 1]),
      with _NUM_CHANNELS channels.
    label: Scalar int64 class id.
  """
  features = {'image/encoded': tf.FixedLenFeature([], tf.string),
              'image/class/label': tf.FixedLenFeature([], tf.int64)}
  parsed_features = tf.parse_single_example(tfrecord_serialized, features)
  image = parsed_features['image/encoded']
  label = parsed_features['image/class/label']
  # An explicit channel count gives the decoded tensor a static channel
  # dimension instead of an unknown one.
  image = tf.image.decode_png(image, channels=_NUM_CHANNELS)
  image = tf.divide(image, 255)
  return image, label

def get_init_data(train_batch,
                  test_batch,
                  dataset_dir=_DATASET_DIR,
                  test_only=False,
                  num_parallel_calls=8):
  """Build a reinitializable input pipeline shared by train and test data.

  The returned tensors yield nothing until one of the returned initializers
  has been run with sess.run(...). Run train_init before a training epoch and
  test_init before evaluation.

  Args:
    train_batch: Batch size of the training set.
    test_batch: Batch size of the test set.
    dataset_dir: Optional. Directory that contains the TFRecord files.
    test_only: If True, only build the test input pipeline; train_init is
      then None.
    num_parallel_calls: Number of records parsed in parallel.

  Returns:
    img: Image batch tensor of shape
      [batch, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS].
    label: One-hot label batch tensor with _NUM_LABELS classes.
    train_init: Initializer op for the training data, or None if test_only.
    test_init: Initializer op for the test data.
  """
  with tf.name_scope('data'):
    testing_filename = _get_output_filename(dataset_dir, 'test')
    test_data = tf.data.TFRecordDataset(testing_filename)
    test_data = test_data.map(_parse_function,
                              num_parallel_calls=num_parallel_calls)
    test_data = test_data.batch(test_batch)
    test_data = test_data.prefetch(test_batch)

    # One iterator serves both splits; which split it reads from depends on
    # which initializer was run last.
    iterator = tf.data.Iterator.from_structure(test_data.output_types,
                                               test_data.output_shapes)
    test_init = iterator.make_initializer(test_data)  # initializer for test_data
    img, label = iterator.get_next()
    # Reshape to [batch, 28, 28, 1] so the images work with tf.nn.conv2d.
    img = tf.reshape(img, shape=[-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS])
    label = tf.one_hot(label, _NUM_LABELS)

    train_init = None
    if not test_only:
      training_filename = _get_output_filename(dataset_dir, 'train')
      train_data = tf.data.TFRecordDataset([training_filename])
      train_data = train_data.shuffle(10000)
      train_data = train_data.map(_parse_function,
                                  num_parallel_calls=num_parallel_calls)
      train_data = train_data.batch(train_batch)
      train_data = train_data.prefetch(train_batch)
      train_init = iterator.make_initializer(train_data)  # initializer for train_data
    return img, label, train_init, test_init

def get_one_shot_test_data(
        test_batch,
        dataset_dir=_DATASET_DIR,
        num_parallel_calls=8):
  """Build a test-set input pipeline that needs no explicit initialization.

  A one-shot iterator is used, so the returned tensors can be evaluated
  directly. This is the form needed for `vai_p_tensorflow --ana`.

  Args:
    test_batch: Batch size of the test set.
    dataset_dir: Optional. Directory that contains the TFRecord files.
    num_parallel_calls: Number of records parsed in parallel.

  Returns:
    img: Image batch tensor of shape
      [batch, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS].
    label: One-hot label batch tensor with _NUM_LABELS classes.
  """
  with tf.name_scope('data'):
    record_path = _get_output_filename(dataset_dir, 'test')
    dataset = (tf.data.TFRecordDataset([record_path])
               .map(_parse_function, num_parallel_calls=num_parallel_calls)
               .batch(test_batch)
               .prefetch(test_batch))
    img, label = dataset.make_one_shot_iterator().get_next()
    # Reshape to [batch, 28, 28, 1] so the images work with tf.nn.conv2d.
    img = tf.reshape(img, shape=[-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS])
    label = tf.one_hot(label, _NUM_LABELS)
    return img, label

if __name__ == '__main__':
  # Run as a script: download MNIST and convert it to TFRecords in _DATASET_DIR.
  download_and_convert(dataset_dir=_DATASET_DIR)

The data_utils module supplies the function get_init_data, which takes train_batch and test_batch as arguments. It returns image and label tensors, together with initializer operations for the training data and the test data. Run the training initializer before each training epoch and the test initializer before each evaluation.

The data_utils.py file is imported as a module to provide the input data pipeline. You can also run it from the shell to download the MNIST dataset and convert it into TFRecord format using the following command:

$ python data_utils.py

This generates the following:

data/mnist/labels.txt
data/mnist/mnist_test.tfrecord
data/mnist/mnist_train.tfrecord
data/mnist/t10k-images-idx3-ubyte.gz
data/mnist/t10k-labels-idx1-ubyte.gz
data/mnist/train-images-idx3-ubyte.gz
data/mnist/train-labels-idx1-ubyte.gz

Build the CNN MNIST Classifier

Create a file called low_level_cnn.py, and add the following code:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from data_utils import get_one_shot_test_data

TEST_BATCH = 100  # batch size of the evaluation pipeline built in model_fn

def conv_relu(inputs, filters, k_size, stride, padding, scope_name):
    '''
    Convolution + bias add + ReLU on `inputs`.

    Creates (or, with tf.AUTO_REUSE, reuses) two variables in `scope_name`:
    'kernel' of shape [k_size, k_size, in_channels, filters] and 'biases'
    of shape [filters].

    Args:
        inputs: 4-D input tensor. Its last dimension is used as in_channels,
            and the [1, stride, stride, 1] strides assume an NHWC layout.
        filters: Number of output channels.
        k_size: Height and width of the square kernel.
        stride: Stride along height and width.
        padding: Padding mode passed to tf.nn.conv2d ('SAME' or 'VALID').
        scope_name: Variable scope for the variables and the conv op.

    Returns:
        The ReLU output tensor, requested with name=scope.name.
    '''
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:
        in_channels = inputs.shape[-1]
        # NOTE(review): truncated_normal_initializer() is used with TF's
        # default stddev (1.0 per TF docs — confirm), which is large for conv
        # weights. Changing it would change training results.
        kernel = tf.get_variable('kernel', 
                                [k_size, k_size, in_channels, filters], 
                                initializer=tf.truncated_normal_initializer())
        biases = tf.get_variable('biases', 
                                [filters],
                                initializer=tf.random_normal_initializer())
        conv = tf.nn.conv2d(inputs, kernel, strides=[1, stride, stride, 1], padding=padding)
    # bias_add and relu are created after leaving the `with` block, so they
    # are not under the `scope_name` name scope. Moving them would rename
    # graph nodes, which the pruning tooling may reference.
    return tf.nn.relu(tf.nn.bias_add(conv, biases), name=scope.name)

def maxpool(inputs, ksize, stride, padding='VALID', scope_name='pool'):
    '''Max pooling over square `ksize` x `ksize` windows of `inputs`.

    Args:
        inputs: 4-D input tensor (NHWC assumed, matching the ksize/strides
            layout below).
        ksize: Height and width of the pooling window.
        stride: Stride along height and width.
        padding: 'VALID' (default) or 'SAME'.
        scope_name: Scope under which the pooling op is created.

    Returns:
        The pooled tensor.
    '''
    # `scope` is not used; the scope only provides the op's name prefix.
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:
        pool = tf.nn.max_pool(inputs, 
                            ksize=[1, ksize, ksize, 1], 
                            strides=[1, stride, stride, 1],
                            padding=padding)
    return pool

def fully_connected(inputs, out_dim, scope_name='fc'):
    '''
    A fully connected linear layer on inputs: inputs @ weights + b, with no
    activation.

    Creates (or reuses) 'weights' of shape [in_dim, out_dim] and 'b' of
    shape [out_dim] in `scope_name`.

    Args:
        inputs: Tensor expected to be 2-D [batch, in_dim]; in_dim is read
            from its last dimension.
        out_dim: Number of output units.
        scope_name: Variable scope name.

    Returns:
        Tensor of shape [batch, out_dim].
    '''
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:
        in_dim = inputs.shape[-1]
        w = tf.get_variable('weights', [in_dim, out_dim],
                            initializer=tf.truncated_normal_initializer())
        b = tf.get_variable('b', [out_dim],
                            initializer=tf.constant_initializer(0.0))
        # The plain `+` (rather than tf.nn.bias_add) creates an op named
        # '<scope_name>/add'. For the final layer this is 'logits/add', which
        # the vai_p_tensorflow scripts use as OUTPUT_NODES.
        out = tf.matmul(inputs, w) + b
    return out

def net_fn(image, n_classes=10, keep_prob=0.5, is_training=True):
    '''Build the MNIST CNN and return its logits.

    Layers: conv1 (5x5, 32) -> pool1 (2x2) -> conv2 (5x5, 64) -> pool2 (2x2)
    -> flatten -> fc (1024) -> ReLU -> dropout -> logits (n_classes).

    Args:
        image: Input batch [batch, height, width, channels] with a static
            spatial shape (the scripts here feed [N, 28, 28, 1]).
        n_classes: Number of output classes.
        keep_prob: Dropout keep probability (float or tensor), used only
            when is_training is True.
        is_training: If False, keep_prob is replaced by 1.

    Returns:
        The logits tensor [batch, n_classes] from the 'logits' scope.
    '''
    conv1 = conv_relu(inputs=image,
                    filters=32,
                    k_size=5,
                    stride=1,
                    padding='SAME',
                    scope_name='conv1')
    pool1 = maxpool(conv1, 2, 2, 'VALID', 'pool1')
    conv2 = conv_relu(inputs=pool1,
                    filters=64,
                    k_size=5,
                    stride=1,
                    padding='SAME',
                    scope_name='conv2')
    pool2 = maxpool(conv2, 2, 2, 'VALID', 'pool2')
    # The flatten size comes from pool2's static shape, so the input must
    # have known spatial dimensions.
    feature_dim = pool2.shape[1] * pool2.shape[2] * pool2.shape[3]
    pool2 = tf.reshape(pool2, [-1, feature_dim])
    fc = fully_connected(pool2, 1024, 'fc')
    keep_prob = keep_prob if is_training else 1
    dropout = tf.nn.dropout(tf.nn.relu(fc), keep_prob, name='relu_dropout')
    logits = fully_connected(dropout, n_classes, 'logits')
    return logits
# Read by export_inf_graph.py when --image_size is not given.
net_fn.default_image_size=28

def model_fn():
  """Build the evaluation graph for `vai_p_tensorflow --eval_fn_path`.

  Reads the test set through a one-shot iterator, runs net_fn in inference
  mode and returns the metrics.

  Returns:
    Dict mapping 'accuracy' and 'recall_5' to the (value, update_op) pairs
    created by tf.metrics.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  images, one_hot_labels = get_one_shot_test_data(TEST_BATCH)

  logits = net_fn(images, is_training=False)
  predicted_classes = tf.argmax(logits, 1)
  true_classes = tf.argmax(one_hot_labels, 1)
  accuracy = tf.metrics.accuracy(true_classes, predicted_classes)
  recall_at_5 = tf.metrics.recall_at_k(true_classes, logits, 5)
  return {'accuracy': accuracy, 'recall_5': recall_at_5}

The net_fn function defines the network architecture. It takes a batch of MNIST images as its argument and returns a logits tensor. The model_fn function reads the test input data pipeline, builds the network, and returns a dictionary of evaluation metric operations.

Model Building, Training and Evaluating

Create a file called train_eval_utils.py, and add the following code:

import os, time, sys
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

import tensorflow as tf

from low_level_cnn import net_fn
from data_utils import get_init_data

class ConvNet(object):
  """Builds, trains and evaluates the low-level MNIST CNN defined by net_fn.

  Usage: construct an instance, call build() to create the graph, then call
  train_eval() to train (optionally starting from a checkpoint) or
  evaluate() to measure the accuracy of a saved checkpoint.
  """

  def __init__(self, training=True):
    """Set hyper-parameters and create graph-level state.

    Args:
      training: If True, build() also creates the loss and the optimizer, and
        the network is built with dropout. If False, only evaluation ops are
        built.
    """
    self.lr = 0.001
    self.train_batch = 128
    self.test_batch = 100
    # Dropout keep probability. It defaults to 0.75 for training steps.
    # eval_once() feeds 1.0, so evaluation runs without dropout even for a
    # network built with training=True.
    self.keep_prob = tf.placeholder_with_default(0.75, shape=[], name='keep_prob')
    self.gstep = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')
    self.n_classes = 10
    self.skip_step = 100  # log the training loss every `skip_step` steps
    self.n_test = 10000   # number of test examples; turns counts into accuracy
    self.training = training

  def loss(self):
    '''
    Define the loss: the mean softmax cross entropy between self.label
    (one-hot) and self.logits. Softmax is applied internally.

    Note: this replaces the `loss` method on the instance with the resulting
    loss tensor, so it can only be called once per instance.
    '''
    with tf.name_scope('loss'):
      entropy = tf.nn.softmax_cross_entropy_with_logits(labels=self.label, logits=self.logits)
      self.loss = tf.reduce_mean(entropy, name='loss')

  def optimize(self):
    '''
    Define the training op: Adam minimizing self.loss and incrementing
    self.gstep.
    '''
    self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss, global_step=self.gstep)

  def eval(self):
    '''
    Count the number of correct predictions in a batch.

    self.accuracy holds a count (a float sum), not a ratio. eval_once()
    divides the total by self.n_test.
    '''
    with tf.name_scope('predict'):
      preds = tf.nn.softmax(self.logits)
      correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(self.label, 1))
      self.accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))

  def summary(self):
    '''
    Create TensorBoard summaries (accuracy always, loss only when training)
    and merge them into self.summary_op.
    '''
    with tf.name_scope('summaries'):
      tf.summary.scalar('accuracy', self.accuracy)
      if self.training:
        tf.summary.scalar('loss', self.loss)
        tf.summary.histogram('histogram_loss', self.loss)
      self.summary_op = tf.summary.merge_all()

  def build(self, test_only=False):
    '''
    Build the computation graph.

    Args:
      test_only: If True, only the test input pipeline is created and
        self.train_init is None.
    '''
    self.img, self.label, self.train_init, self.test_init = get_init_data(
        self.train_batch, self.test_batch, test_only=test_only)

    self.logits = net_fn(self.img, n_classes=self.n_classes,
                         keep_prob=self.keep_prob, is_training=self.training)
    if self.training:
      self.loss()
      self.optimize()
    self.eval()
    self.summary()

  def train_one_epoch(self, sess, saver, writer, epoch, step):
    '''
    Run one pass over the training set.

    Args:
      sess: Active session.
      saver: Unused; kept for interface compatibility. train_eval() saves
        the checkpoint once, at the end.
      writer: tf.summary.FileWriter that receives per-step summaries.
      epoch: Epoch index, used only for logging.
      step: Global step at the start of the epoch.

    Returns:
      The global step after the epoch.
    '''
    start_time = time.time()
    sess.run(self.train_init)
    total_loss = 0
    n_batches = 0
    tf.logging.info(time.strftime('time:%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    try:
      while True:
        _, l, summaries = sess.run([self.opt, self.loss, self.summary_op])
        writer.add_summary(summaries, global_step=step)
        if (step + 1) % self.skip_step == 0:
          tf.logging.info('Loss at step {0}: {1}'.format(step+1, l))
        step += 1
        total_loss += l
        n_batches += 1
    except tf.errors.OutOfRangeError:
      pass  # training dataset exhausted: end of the epoch
    # max(..., 1) avoids a ZeroDivisionError if the dataset yielded no batch.
    tf.logging.info('Average loss at epoch {0}: {1}'.format(
        epoch, total_loss / max(n_batches, 1)))
    tf.logging.info('train one epoch took: {0} seconds'.format(time.time() - start_time))
    return step

  def eval_once(self, sess, writer=None, step=None):
    '''
    Evaluate on the whole test set and log the accuracy.

    Dropout is disabled by feeding keep_prob=1.0.

    Args:
      sess: Active session.
      writer: Optional tf.summary.FileWriter for the evaluation summaries.
      step: Global step recorded with the summaries.
    '''
    start_time = time.time()
    sess.run(self.test_init)
    total_correct_preds = 0
    try:
      while True:
        accuracy_batch, summaries = sess.run(
            [self.accuracy, self.summary_op],
            feed_dict={self.keep_prob: 1.0})
        if writer is not None:
          writer.add_summary(summaries, global_step=step)
        total_correct_preds += accuracy_batch
    except tf.errors.OutOfRangeError:
      pass  # test dataset exhausted
    tf.logging.info('Evaluation took: {0} seconds'.format(time.time() - start_time))
    tf.logging.info('Accuracy : {0} \n'.format(total_correct_preds/self.n_test))

  def train_eval(self, n_epochs=10, save_ckpt=None, restore_ckpt=None):
    '''
    Alternate between training one epoch and evaluating, then save.

    Args:
      n_epochs: Number of training epochs.
      save_ckpt: Checkpoint path prefix for the final model. If empty or
        None, the model is not saved and a warning is logged.
      restore_ckpt: Optional checkpoint to start from (fine-tuning). When it
        is given, summaries go to ./graphs/convnet/finetune instead of
        ./graphs/convnet/train.
    '''
    if restore_ckpt:
      writer = tf.summary.FileWriter('./graphs/convnet/finetune', tf.get_default_graph())
    else:
      writer = tf.summary.FileWriter('./graphs/convnet/train', tf.get_default_graph())
    try:
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        if restore_ckpt:
          saver.restore(sess, restore_ckpt)
        step = self.gstep.eval()
        for epoch in range(n_epochs):
          step = self.train_one_epoch(sess, saver, writer, epoch, step)
          self.eval_once(sess, writer, step)
        if save_ckpt:
          saver.save(sess, save_ckpt)
        else:
          tf.logging.warning('No save_ckpt given; the trained model is not saved.')
    finally:
      # Close the writer even if training fails.
      writer.close()
    tf.logging.info("Finish")

  def evaluate(self, restore_ckpt):
    '''
    Restore `restore_ckpt` and evaluate it once on the test set.
    '''
    with tf.Session() as sess:
      saver = tf.train.Saver()
      saver.restore(sess, restore_ckpt)
      self.eval_once(sess)
    tf.logging.info("Finish")

ConvNet is a class that builds the graph and trains and evaluates the model. It combines the data utilities, the network definition, and the metrics. To use it, instantiate ConvNet and call its build method to construct the graph; pass test_only=True to build only the test input pipeline. Then call train_eval to train the model, or evaluate to evaluate a saved checkpoint.

Train the Model

To train the model, create a file named train.py and add the following code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from train_eval_utils import ConvNet

# Command-line flag: checkpoint path prefix for the trained model.
tf.app.flags.DEFINE_string('save_ckpt', '', 'Where to save checkpoint.')
FLAGS = tf.app.flags.FLAGS

def main(unused_argv):
  """Train the low-level CNN from scratch and save it to --save_ckpt."""
  # Fail fast: without a path, the model could not be saved after training.
  if not FLAGS.save_ckpt:
    raise ValueError('You must supply the path to save to with --save_ckpt')
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info("Training model from scratch")
  net = ConvNet(True)
  net.build()
  net.train_eval(10, FLAGS.save_ckpt)

if __name__ == '__main__':
  # tf.app.run parses the command-line flags and then calls main().
  tf.app.run(main=main)

Run train.py in shell:

$ WORKSPACE=./models
$ BASELINE_CKPT=${WORKSPACE}/train/model.ckpt
$ mkdir -p $(dirname "${BASELINE_CKPT}")
$ python train.py --save_ckpt=${BASELINE_CKPT}

The running output log looks like this:

INFO:tensorflow:time:2019-01-09 16:14:44
INFO:tensorflow:Loss at step 500: 421.8246154785156
INFO:tensorflow:Loss at step 600: 305.761474609375
INFO:tensorflow:Loss at step 700: 167.25115966796875
INFO:tensorflow:Loss at step 800: 399.25732421875
INFO:tensorflow:Loss at step 900: 246.51300048828125
INFO:tensorflow:Average loss at epoch 1: 390.06004813383385
INFO:tensorflow:train one epoch took: 2.353825569152832 seconds
INFO:tensorflow:Evaluation took: 0.22740554809570312 seconds
INFO:tensorflow:Accuracy : 0.9435

After a few minutes, you get a trained checkpoint: models/train/model.ckpt.

Export an Inference GraphDef File

Create a file named export_inf_graph.py and add the following code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.platform import gfile
from google.protobuf import text_format
from low_level_cnn import net_fn

# Command-line flags for exporting the inference graph.
tf.app.flags.DEFINE_integer(
    'image_size', None,
    'The image size to use, otherwise use the model default_image_size.')
tf.app.flags.DEFINE_integer(
    'batch_size', None,
    'Batch size for the exported model. Defaulted to "None" so batch size can '
    'be specified at model runtime.')
# NOTE(review): dataset_name is not read by main(); it is kept only so that
# existing command lines that pass it keep working.
tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet',
    'The name of the dataset to use with the model.')
tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

FLAGS = tf.app.flags.FLAGS

def main(_):
  """Write net_fn's inference GraphDef, in text format, to --output_file.

  The graph has a float32 placeholder named 'image' of shape
  [batch_size, image_size, image_size, 1] feeding net_fn(is_training=False).
  """
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as graph:
    size = FLAGS.image_size or net_fn.default_image_size
    input_shape = [FLAGS.batch_size, size, size, 1]
    image = tf.placeholder(name='image', dtype=tf.float32, shape=input_shape)
    net_fn(image, is_training=False)
    graph_def = graph.as_graph_def()

    with gfile.GFile(FLAGS.output_file, 'w') as out_file:
      out_file.write(text_format.MessageToString(graph_def))
    tf.logging.info("Finish export inference graph")

if __name__ == '__main__':
  # Parse the command-line flags and then call main().
  tf.app.run(main=main)

Run export_inf_graph.py.

$ WORKSPACE=./models
$ BASELINE_GRAPH=${WORKSPACE}/mnist.pbtxt
$ python export_inf_graph.py --output_file=${BASELINE_GRAPH}

Run Model Analysis

Now that you have prepared a trained checkpoint and a GraphDef file, you can start the pruning process. Run the following shell scripts to call the vai_p_tensorflow functions.

WORKSPACE=./models
BASELINE_GRAPH=${WORKSPACE}/mnist.pbtxt
BASELINE_CKPT=${WORKSPACE}/train/model.ckpt
INPUT_NODES="image"
OUTPUT_NODES="logits/add"
action=ana

# Analyze the baseline model. model_fn in low_level_cnn.py supplies the
# evaluation metrics, and "accuracy" is the metric to track.
vai_p_tensorflow \
  --action="${action}" \
  --input_graph="${BASELINE_GRAPH}" \
  --input_ckpt="${BASELINE_CKPT}" \
  --eval_fn_path=low_level_cnn.py \
  --target="accuracy" \
  --max_num_batches=100 \
  --workspace="${WORKSPACE}" \
  --input_nodes="${INPUT_NODES}" \
  --input_node_shapes="1,28,28,1" \
  --output_nodes="${OUTPUT_NODES}"

The output log is as shown below:

INFO:tensorflow:Starting evaluation at 2019-01-09-08:43:15
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./models/train/model.ckpt
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Finished evaluation at 2019-01-09-08:43:21

Prune the Model

You can now prune the model. Run the following shell script to call vai_p_tensorflow:

WORKSPACE=./models
BASELINE_GRAPH=${WORKSPACE}/mnist.pbtxt
BASELINE_CKPT=${WORKSPACE}/train/model.ckpt
PRUNED_GRAPH=${WORKSPACE}/pruned/graph.pbtxt
PRUNED_CKPT=${WORKSPACE}/pruned/sparse.ckpt
INPUT_NODES="image"
OUTPUT_NODES="logits/add"
action=prune

# Write the pruned graph and sparse checkpoint under ${WORKSPACE}/pruned,
# with a target sparsity of 0.5. Output is also logged to prune.log.
mkdir -p "$(dirname "${PRUNED_GRAPH}")"
vai_p_tensorflow \
  --action="${action}" \
  --input_graph="${BASELINE_GRAPH}" \
  --input_ckpt="${BASELINE_CKPT}" \
  --output_graph="${PRUNED_GRAPH}" \
  --output_ckpt="${PRUNED_CKPT}" \
  --workspace="${WORKSPACE}" \
  --input_nodes="${INPUT_NODES}" \
  --input_node_shapes="1,28,28,1" \
  --output_nodes="${OUTPUT_NODES}" \
  --sparsity=0.5 \
  --gpu="0,1,2,3" \
  2>&1 | tee prune.log

Finetune the Pruned Model

Create a file named ft.py and add the following code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from train_eval_utils import ConvNet

# Command-line flags: the pruned checkpoint to start from, and where to save
# the fine-tuned model.
tf.app.flags.DEFINE_string('checkpoint_path', '', 'Where to restore checkpoint.')
tf.app.flags.DEFINE_string('save_ckpt', '', 'Where to save checkpoint.')
FLAGS = tf.app.flags.FLAGS

def main(unused_argv):
  """Fine-tune the pruned checkpoint with sparse training enabled."""
  # Without a checkpoint, train_eval() would silently train from scratch.
  if not FLAGS.checkpoint_path:
    raise ValueError('You must supply the pruned checkpoint with --checkpoint_path')
  if not FLAGS.save_ckpt:
    raise ValueError('You must supply the path to save to with --save_ckpt')
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info("Finetuning model")

  # Must run before the model is created, so that pruned channels stay at
  # zero during training.
  tf.set_pruning_mode()
  net = ConvNet(True)
  net.build()
  net.train_eval(10, FLAGS.save_ckpt, FLAGS.checkpoint_path)

if __name__ == '__main__':
  # Parse the command-line flags and then call main().
  tf.app.run(main=main)
Note: You must call tf.set_pruning_mode() before creating the model. The API is used to enable “sparse training” mode, that is, the weights of pruned channels will be kept to 0 and will not be updated during training. If you fine-tune a pruned model without calling this function, the pruned channels will be updated and finally you will get a normal non-sparse model.

Fine-tuning the pruned model is similar to training a model from scratch. Run ft.py:

WORKSPACE=./models
FT_CKPT=${WORKSPACE}/ft/model.ckpt
PRUNED_CKPT=${WORKSPACE}/pruned/sparse.ckpt
# tf.train.Saver does not create the parent directory of the save path, so
# create it first.
mkdir -p "$(dirname "${FT_CKPT}")"
# -u disables Python output buffering, so tee shows progress immediately.
python -u ft.py \
    --save_ckpt="${FT_CKPT}" \
    --checkpoint_path="${PRUNED_CKPT}" \
    2>&1 | tee ft.log

The output log looks like:

INFO:tensorflow:time:2019-01-09 17:17:10
INFO:tensorflow:Loss at step 1000: 13.077235221862793
INFO:tensorflow:Loss at step 1100: 41.67073440551758
INFO:tensorflow:Loss at step 1200: 31.98809242248535
INFO:tensorflow:Loss at step 1300: 34.46034240722656
INFO:tensorflow:Loss at step 1400: 32.12882995605469
INFO:tensorflow:Average loss at epoch 2: 28.96098704302489
INFO:tensorflow:train one epoch took: 3.0082509517669678 seconds
INFO:tensorflow:Evaluation took: 0.23403644561767578 seconds
INFO:tensorflow:Accuracy : 0.9539

As a final step, you need to transform and freeze the fine-tuned model to get a dense model.

WORKSPACE=./models
FT_CKPT=${WORKSPACE}/ft/model.ckpt
TRANSFORMED_CKPT=${WORKSPACE}/pruned/transformed.ckpt
PRUNED_GRAPH=${WORKSPACE}/pruned/graph.pbtxt
FROZEN_PB=${WORKSPACE}/pruned/mnist.pb
OUTPUT_NODES="logits/add"

# Transform the fine-tuned sparse checkpoint into a dense checkpoint.
vai_p_tensorflow \
    --action=transform \
    --input_ckpt="${FT_CKPT}" \
    --output_ckpt="${TRANSFORMED_CKPT}"

# Freeze the pruned text GraphDef with the transformed weights into a .pb.
freeze_graph \
    --input_graph="${PRUNED_GRAPH}" \
    --input_checkpoint="${TRANSFORMED_CKPT}" \
    --input_binary=false \
    --output_graph="${FROZEN_PB}" \
    --output_node_names="${OUTPUT_NODES}"

Finally, you should have a frozen GraphDef file named mnist.pb in models/pruned.

Estimator

Build the CNN MNIST Classifier

Create a file named est_cnn.py and add the following code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Imports
import numpy as np
import tensorflow as tf

# Model function for tf.estimator.Estimator (passed as model_fn below).
def cnn_model_fn(features, labels, mode):
  """Model function for CNN.

  Args:
    features: Dict whose "x" entry holds the image batch. It is reshaped to
      [-1, 28, 28, 1], so each example must contain 28 * 28 values.
    labels: Integer class ids (not one-hot); unused in PREDICT mode.
    mode: A tf.estimator.ModeKeys value.

  Returns:
    A tf.estimator.EstimatorSpec. It carries predictions in PREDICT mode,
    loss and train_op in TRAIN mode, and loss and an accuracy metric in
    EVAL mode.
  """
  # Input Layer
  input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])

  # Convolutional Layer #1
  conv1 = tf.layers.conv2d( 
      inputs=input_layer,
      filters=32,
      kernel_size=[5, 5],
      padding="same",
      activation=tf.nn.relu)

  # Pooling Layer #1
  pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

  # Convolutional Layer #2 and Pooling Layer #2
  conv2 = tf.layers.conv2d(
      inputs=pool1,
      filters=64,
      kernel_size=[5, 5],
      padding="same",
      activation=tf.nn.relu)
  pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

  # Dense Layer. Two 2x2/stride-2 poolings reduce 28x28 to 7x7 with 64
  # channels.
  pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
  dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
  # Dropout (rate 0.4) is active only in TRAIN mode.
  dropout = tf.layers.dropout(
      inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)

  # Logits Layer
  logits = tf.layers.dense(inputs=dropout, units=10)

  predictions = {
      # Generate predictions (for PREDICT and EVAL mode)
      "classes": tf.argmax(input=logits, axis=1),
      # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
      # `logging_hook`.
      "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Calculate Loss (for both TRAIN and EVAL modes)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  # Configure the Training Op (for TRAIN mode)
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    train_op = optimizer.minimize(
        loss=loss,
        global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  # Add evaluation metrics (for EVAL mode)
  eval_metric_ops = {
      "accuracy": tf.metrics.accuracy(
          labels=labels, predictions=predictions["classes"])}
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

# Load training and eval data once at import time. The input functions
# below close over these arrays.
mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data, eval_data = mnist.train.images, mnist.test.images  # np.arrays
train_labels, eval_labels = [
    np.asarray(split.labels, dtype=np.int32)
    for split in (mnist.train, mnist.test)]

def train_input_fn():
  """Return the Estimator input function used for training.

  The returned callable (built by tf.estimator.inputs.numpy_input_fn) feeds
  shuffled batches of 100 examples from train_data/train_labels; with
  num_epochs=None it keeps cycling through the data until training stops.
  """
  features = {"x": train_data}
  return tf.estimator.inputs.numpy_input_fn(
      x=features, y=train_labels,
      batch_size=100, num_epochs=None, shuffle=True)

def eval_input_fn():
  """Return the Estimator input function used for evaluation.

  The callable makes a single, unshuffled pass (num_epochs=1) over
  eval_data/eval_labels, using numpy_input_fn's default batch size.
  """
  features = {"x": eval_data}
  return tf.estimator.inputs.numpy_input_fn(
      x=features, y=eval_labels, num_epochs=1, shuffle=False)

def model_fn():
  """Create the MNIST Estimator around cnn_model_fn.

  Checkpoints are read from and written to ./models/train/.
  """
  estimator = tf.estimator.Estimator(
      model_fn=cnn_model_fn, model_dir="./models/train/")
  return estimator

The cnn_model_fn function conforms to the interface expected by the TensorFlow Estimator API. It takes MNIST feature data, labels, and mode as arguments; creates the convolution, pooling, and dense layers; and returns an EstimatorSpec containing the predictions, the loss, and (in training mode) a training operation.

train_input_fn and eval_input_fn return the input functions that feed data to the network during training and evaluation, respectively.

Train Baseline Model

To train the model, create an Estimator and call train() on it. Create a file named train.py and add the following code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from est_cnn import model_fn, train_input_fn, eval_input_fn

# Imports
import numpy as np
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)

def main(unused_argv):
  """Train the baseline MNIST model, then evaluate it and print the metrics.

  Training runs until global step 20000. The evaluation results dict is
  printed to stdout.
  """
  classifier = model_fn()
  classifier.train(input_fn=train_input_fn(), max_steps=20000)
  results = classifier.evaluate(input_fn=eval_input_fn())
  print(results)

if __name__ == "__main__":
  tf.app.run()

Run train.py.

$ python train.py

As the model trains, an output similar to the following is displayed:

INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into ./models/train/model.ckpt.
INFO:tensorflow:loss = 2.294087, step = 0
INFO:tensorflow:global_step/sec: 201.741
INFO:tensorflow:loss = 2.2876544, step = 100 (0.496 sec)
INFO:tensorflow:global_step/sec: 228.126
INFO:tensorflow:loss = 2.2656975, step = 200 (0.439 sec)
INFO:tensorflow:global_step/sec: 225.094
INFO:tensorflow:loss = 2.2483034, step = 300 (0.444 sec)
INFO:tensorflow:global_step/sec: 234.019
…
INFO:tensorflow:Saving dict for global step 20000: accuracy = 0.9684, global_step = 20000, loss = 0.10172604
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20000: ./models/train/model.ckpt-20000
{'accuracy': 0.9684, 'loss': 0.10172604, 'global_step': 20000}

The baseline model reaches an accuracy of 96.84% on the test set.

Export an Inference GraphDef File

Create a file named export_inf_graph.py and add the following code:

from google.protobuf import text_format
from est_cnn import cnn_model_fn
from tensorflow.keras import backend as K
from tensorflow.python.platform import gfile
import tensorflow as tf

tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

FLAGS = tf.app.flags.FLAGS

def main(_):
  """Write the inference GraphDef of cnn_model_fn to --output_file as text.

  The graph is built in EVAL mode with batch-size-1 placeholders. The
  'image' placeholder (shape [1, 28, 28, 1]) is the node later passed to
  vai_p_tensorflow through INPUT_NODES / --input_node_shapes.

  Raises:
    ValueError: if --output_file is empty.
  """
  output_path = FLAGS.output_file
  if not output_path:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as graph:
    image = tf.placeholder(
        name='image', dtype=tf.float32, shape=[1, 28, 28, 1])
    label = tf.placeholder(name='label', dtype=tf.int32, shape=[1])
    cnn_model_fn({"x": image}, label, tf.estimator.ModeKeys.EVAL)

    text_graph = text_format.MessageToString(graph.as_graph_def())
    with gfile.GFile(output_path, 'w') as f:
      f.write(text_graph)
    print("Finish export inference graph")

if __name__ == '__main__':
  tf.app.run()

Run Model Analysis

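Run export_inf_graph.py with --output_file=./models/mnist.pbtxt to write the GraphDef file that the scripts below expect (BASELINE_GRAPH). Now that you have prepared a trained checkpoint and a GraphDef file, you can start pruning.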

Write some shell scripts to call vai_p_tensorflow functions.

WORKSPACE=./models

BASELINE_GRAPH=${WORKSPACE}/mnist.pbtxt
BASELINE_CKPT=${WORKSPACE}/train/model.ckpt-20000
INPUT_NODES="image"
OUTPUT_NODES="softmax_tensor"

# Model analysis. --target names the "accuracy" eval metric defined in
# cnn_model_fn in est_cnn.py (the file given to --eval_fn_path).
action=ana
vai_p_tensorflow \
    --action=${action} \
    --input_graph=${BASELINE_GRAPH} \
    --input_ckpt=${BASELINE_CKPT} \
    --eval_fn_path=est_cnn.py \
    --target="accuracy" \
    --max_num_batches=500 \
    --workspace=${WORKSPACE} \
    --input_nodes="${INPUT_NODES}" \
    --input_node_shapes="1,28,28,1" \
    --output_nodes="${OUTPUT_NODES}"

In est_cnn.py, you previously defined a tf.metrics.accuracy operation named "accuracy" that calculates the accuracy of your model:

# Excerpt from cnn_model_fn in est_cnn.py (EVAL mode):
eval_metric_ops = {
      "accuracy": tf.metrics.accuracy(
          labels=labels, predictions=predictions["classes"])}

Use this operation to evaluate the performance of your model by setting --target="accuracy".

Prune the Model

PRUNED_GRAPH=${WORKSPACE}/pruned/graph.pbtxt
PRUNED_CKPT=${WORKSPACE}/pruned/sparse.ckpt

# Prune the baseline model and write the pruned graph and sparse checkpoint.
action=prune
vai_p_tensorflow \
    --action=${action} \
    --input_graph=${BASELINE_GRAPH} \
    --input_ckpt=${BASELINE_CKPT} \
    --output_graph=${PRUNED_GRAPH} \
    --output_ckpt=${PRUNED_CKPT} \
    --workspace=${WORKSPACE} \
    --input_nodes="${INPUT_NODES}" \
    --input_node_shapes="1,28,28,1" \
    --output_nodes="${OUTPUT_NODES}" \
    --sparsity=0.2 \
    --gpu="0,1,2,3"

Finetune the Pruned Model

Create a file named ft.py and add the following code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from est_cnn import cnn_model_fn, train_input_fn

# Imports
import numpy as np
import tensorflow as tf

tf.app.flags.DEFINE_string(
    'checkpoint_path', None, 'Path of a specific checkpoint to finetune.')

FLAGS = tf.app.flags.FLAGS

tf.logging.set_verbosity(tf.logging.INFO)

def main(unused_argv):
  """Fine-tune the pruned MNIST model for 20000 steps.

  Warm-starts from --checkpoint_path (the sparse checkpoint written by the
  prune step) and writes new checkpoints to ./models/ft/.

  Raises:
    ValueError: if --checkpoint_path is not supplied.
  """
  # Fail fast: without a checkpoint the warm start has nothing to load.
  if not FLAGS.checkpoint_path:
    raise ValueError(
        'You must supply the pruned checkpoint with --checkpoint_path')
  # Not a stock TensorFlow API; presumably provided by the Vitis AI
  # TensorFlow build. Called before the Estimator builds the model graph.
  tf.set_pruning_mode()
  ws = tf.estimator.WarmStartSettings(
      ckpt_to_initialize_from=FLAGS.checkpoint_path)
  mnist_classifier = tf.estimator.Estimator(
      model_fn=cnn_model_fn, model_dir="./models/ft/", warm_start_from=ws)

  mnist_classifier.train(
      input_fn=train_input_fn(),
      max_steps=20000)

if __name__ == "__main__":
  tf.app.run()

tf.estimator.WarmStartSettings loads the pruned checkpoint so that training fine-tunes from it.

Run ft.py to finetune the pruned model:

python -u ft.py --checkpoint_path="${PRUNED_CKPT}"

The output log looks like the following:

INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into ./models/ft/model.ckpt.
INFO:tensorflow:loss = 0.3675258, step = 0
INFO:tensorflow:global_step/sec: 162.673
INFO:tensorflow:loss = 0.31534952, step = 100 (0.615 sec)
INFO:tensorflow:global_step/sec: 210.058
INFO:tensorflow:loss = 0.2782951, step = 200 (0.476 sec)
...
INFO:tensorflow:loss = 0.022076223, step = 19800 (0.503 sec)
INFO:tensorflow:global_step/sec: 206.588
INFO:tensorflow:loss = 0.06927078, step = 19900 (0.484 sec)
INFO:tensorflow:Saving checkpoints for 20000 into ./models/ft/model.ckpt.
INFO:tensorflow:Loss for final step: 0.07726018.

As a final step, transform and freeze the finetuned model to get a dense model.

FT_CKPT=${WORKSPACE}/ft/model.ckpt-20000
TRANSFORMED_CKPT=${WORKSPACE}/pruned/transformed.ckpt
FROZEN_PB=${WORKSPACE}/pruned/mnist.pb

# Convert the fine-tuned sparse checkpoint into a dense one.
vai_p_tensorflow \
    --action=transform \
    --input_ckpt=${FT_CKPT} \
    --output_ckpt=${TRANSFORMED_CKPT}

# Freeze the pruned graph together with the dense weights.
freeze_graph \
    --input_graph="${PRUNED_GRAPH}" \
    --input_checkpoint="${TRANSFORMED_CKPT}" \
    --input_binary=false \
    --output_graph="${FROZEN_PB}" \
    --output_node_names="${OUTPUT_NODES}"

Finally, you have a frozen GraphDef file named mnist.pb.

VGG-16

This sample demonstrates how to run vai_p_tensorflow on real-world models. VGG (https://arxiv.org/abs/1409.1557) is a network for large-scale image recognition. This sample uses a pre-trained VGG-16 model from the TensorFlow-Slim image classification model library.

  1. Download the TensorFlow-Slim repository and its pre-trained VGG16 model.
    $ git clone https://github.com/tensorflow/models.git
    $ cd models/research/slim
    $ mkdir -p models/vgg16 && cd models/vgg16
    $ wget http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
    $ tar xzvf vgg_16_2016_08_28.tar.gz
    
  2. Prepare the ImageNet dataset using the instructions at https://github.com/tensorflow/models/blob/master/research/inception/README.md#getting-started.
  3. Prepare a graph evaluation script named vgg_16_eval.py (the file passed to --eval_fn_path in step 6).
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import math
    import tensorflow as tf
    
    from tensorflow.python.summary import summary
    from tensorflow.python.training import monitored_session
    from tensorflow.python.training import saver as tf_saver
    
    from datasets import dataset_factory
    from nets import nets_factory
    from preprocessing import preprocessing_factory
    
    slim = tf.contrib.slim

    # Evaluation settings read by model_fn().
    dataset_name='imagenet'
    dataset_split_name='validation'
    dataset_dir='/dataset/imagenet/tf_records'
    model_name='vgg_16'
    # Subtracted from both the dataset labels and dataset.num_classes in
    # model_fn(). Presumably 1 because the VGG-16 checkpoint has no
    # background class; confirm against the checkpoint.
    labels_offset=1
    batch_size=100
    num_preprocessing_threads=4

    def model_fn():
      """Build the VGG-16 ImageNet evaluation graph and return its metrics.

      Reads `dataset_split_name` of `dataset_name` from `dataset_dir` without
      shuffling. It then preprocesses the images for evaluation, batches them
      in groups of `batch_size` and runs the network. vai_p_tensorflow picks
      one of the returned metrics with --target.

      Returns:
        A dict mapping 'top-1' and 'top-5' to the results of the slim
        streaming accuracy / recall@5 metrics.
      """
      tf.logging.set_verbosity(tf.logging.INFO)

      # Make sure the graph has a global step; the tensor itself is unused.
      slim.get_or_create_global_step()

      ######################
      # Select the dataset #
      ######################
      dataset = dataset_factory.get_dataset(dataset_name,
                                            dataset_split_name,
                                            dataset_dir)

      ####################
      # Select the model #
      ####################
      network_fn = nets_factory.get_network_fn(
          model_name,
          num_classes=(dataset.num_classes - labels_offset),
          is_training=False)

      ##############################################################
      # Create a dataset provider that loads data from the dataset #
      ##############################################################
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          shuffle=False,
          common_queue_capacity=2 * batch_size,
          common_queue_min=batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= labels_offset

      #####################################
      # Select the preprocessing function #
      #####################################
      preprocessing_name = model_name
      image_preprocessing_fn = preprocessing_factory.get_preprocessing(
          preprocessing_name,
          is_training=False)

      eval_image_size = network_fn.default_image_size

      image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=batch_size,
          num_threads=num_preprocessing_threads,
          capacity=5 * batch_size)

      ####################
      # Define the model #
      ####################
      logits, _ = network_fn(images)

      predictions = tf.argmax(logits, 1)
      # The recall@5 metric receives the labels as batched; the accuracy
      # metric receives a squeezed copy.
      org_labels = labels
      labels = tf.squeeze(labels)

      eval_metric_ops = {
          'top-1': slim.metrics.streaming_accuracy(predictions, labels),
          'top-5': slim.metrics.streaming_recall_at_k(logits, org_labels, 5)
      }
      return eval_metric_ops
    
  4. Modify models/research/slim/export_inference_graph.py to generate an output of a readable text file instead of a binary file.
    + from google.protobuf import text_format
    
    -   with gfile.GFile(FLAGS.output_file, 'wb') as f:
    -     f.write(graph_def.SerializeToString())
    +   with gfile.GFile(FLAGS.output_file, 'w') as f:
    +     f.write(text_format.MessageToString(graph_def))
    
  5. Export an inference graph.
    # Uses export_inference_graph.py as modified in step 4, so the graph is
    # written as readable text (.pbtxt).
    python export_inference_graph.py \
        --model_name=vgg_16 \
        --output_file=vgg_16_inf_graph.pbtxt \
        --dataset_dir=/opt/dataset/tf_records
    
  6. Run model analysis.
    # --target must name a key of the dict returned by model_fn() in
    # vgg_16_eval.py ('top-1' or 'top-5').
    vai_p_tensorflow \
      --action=ana \
      --input_graph=vgg_16_inf_graph.pbtxt \
      --input_ckpt=vgg_16.ckpt \
      --eval_fn_path=vgg_16_eval.py \
      --target=top-5 \
      --max_num_batches=500 \
      --workspace=/home/deephi/models/research/slim/models/vgg16 \
      --exclude="vgg_16/fc6/Conv2D, vgg_16/fc7/Conv2D, vgg_16/fc8/Conv2D" \
      --output_nodes="vgg_16/fc8/squeezed"
    

    vgg_16_eval.py sets batch_size to 100. Because the ImageNet validation set contains 50,000 examples, --max_num_batches is set to 500 (50,000 / 100) so that every example in the validation set is evaluated.

    IMPORTANT: The nodes with the vgg_16/fc prefix affect the number of output labels of the network. Exclude these nodes to prevent a shape mismatch with the dataset.
  7. Run model pruning.
    # Prune with --sparsity=0.15, excluding the same fc nodes as the
    # analysis step.
    vai_p_tensorflow \
      --action=prune \
      --input_graph=vgg_16_inf_graph.pbtxt \
      --input_ckpt=vgg_16.ckpt \
      --output_graph=sparse_graph.pbtxt \
      --output_ckpt=sparse.ckpt \
      --workspace=/home/deephi/models/research/slim/models/vgg16 \
      --sparsity=0.15 \
      --exclude="vgg_16/fc6/Conv2D, vgg_16/fc7/Conv2D, vgg_16/fc8/Conv2D" \
      --output_nodes="vgg_16/fc8/squeezed"
    
  8. Open models/research/slim/train_image_classifier.py and insert the following line at the beginning of the main() function.
    def main(_):
    +  tf.set_pruning_mode()
    
  9. Fine-tune the pruned model.
    # Fine-tune from the pruned checkpoint; requires the step 8 change.
    python train_image_classifier.py \
        --model_name=vgg_16 \
        --train_dir=./models/vgg16/ft \
        --dataset_name=imagenet \
        --dataset_dir=/opt/dataset/tf_records \
        --dataset_split_name=train \
        --checkpoint_path=./models/vgg16/sparse.ckpt \
        --labels_offset=0 \
        --save_interval_secs=600 \
        --batch_size=32 \
        --num_clones=4 \
        --weight_decay=5e-4 \
        --optimizer=adam \
        --learning_rate=1e-2 \
        --learning_rate_decay_type=polynomial \
        --decay_steps=200000 \
        --max_number_of_steps=200000
    
  10. Get a dense checkpoint and freeze graph.
    # Convert the fine-tuned sparse checkpoint into a dense one.
    vai_p_tensorflow \
        --action=transform \
        --input_ckpt=./models/vgg16/ft/model.ckpt-200000 \
        --output_ckpt=./models/vgg16/dense.ckpt

    # Freeze the pruned graph with the dense weights.
    freeze_graph \
        --input_graph=./models/vgg16/sparse_graph.pbtxt \
        --input_checkpoint=./models/vgg16/dense.ckpt \
        --input_binary=false \
        --output_graph=./models/vgg16/vgg16_pruned.pb \
        --output_node_names="vgg_16/fc8/squeezed"
    

ResNet50

This example demonstrates how to prune a Keras model. A pre-defined ResNet50 is used here.

  1. Prepare an evaluation script for model analysis named ResNet50_model.py.
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import tensorflow as tf
    import time
    from preprocessing.dataset import input_fn, NUM_IMAGES
    
    TRAIN_NUM = NUM_IMAGES['train']
    EVAL_NUM = NUM_IMAGES['validation']
    
    DATASET_DIR="/scratch/workspace/dataset/imagenet/tf_records"
    batch_size = 100
    image_size = 224
    def get_input_data(prefix_preprocessing="vgg"):
        """Build the evaluation input pipeline.

        Args:
          prefix_preprocessing: preprocessing style passed through to input_fn.

        Returns:
          The result of input_fn(is_training=False, ...) over DATASET_DIR for
          one epoch of batch_size images of image_size x image_size.
        """
        eval_data = input_fn(
            is_training=False, data_dir=DATASET_DIR,
            output_width=image_size,
            output_height=image_size,
            batch_size=batch_size,
            num_epochs=1,
            num_gpus=1,
            dtype=tf.float32,
            prefix_preprocessing=prefix_preprocessing)
        return eval_data
    
    # Module-level ResNet50 (no pre-trained weights); evaluate() loads the
    # checkpoint into it.
    network_fn = tf.keras.applications.ResNet50(weights=None,
        include_top=True,
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000)
    
    def evaluate(ckpt_path=''):
        """Load ckpt_path into network_fn and evaluate it on the validation set.

        Args:
          ckpt_path: weights file/prefix passed to network_fn.load_weights.

        Returns:
          {'Recall_5': top-5 accuracy}. Pass --target=Recall_5 to
          vai_p_tensorflow to select this metric.
        """
        network_fn.load_weights(ckpt_path)
        metric_top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy()
        accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        loss = tf.keras.losses.SparseCategoricalCrossentropy()
    
        network_fn.compile(loss=loss, metrics=[accuracy, metric_top_5])
        # eval_data: validation dataset. You can refer to the
        # tf.keras.Model.evaluate method to find out the eval_data format and
        # write a data processing function to get your evaluation dataset.
        eval_data = get_input_data()
        # Keras expects an integer step count (EVAL_NUM / batch_size is a
        # float). Round up so that a final partial batch, if input_fn emits
        # one, is included.
        eval_steps = (EVAL_NUM + batch_size - 1) // batch_size
        # start_time was previously never defined, so this function raised
        # NameError after every evaluation.
        start_time = time.time()
        res = network_fn.evaluate(eval_data,
            steps=eval_steps,
            workers=16,
            verbose=1)
        delta_time = time.time() - start_time
        print('Evaluation took {:.2f}s'.format(delta_time))
        # evaluate() returns [loss, accuracy, top-5] following the compile()
        # metric order, so the last entry is the top-5 accuracy.
        recall5 = res[-1]
        eval_metric_ops = {'Recall_5': recall5}
        return eval_metric_ops
    
  2. Export inference graph.
    import tensorflow as tf
    from tensorflow.keras import backend as K
    from tensorflow.python.framework import graph_util
    
    # Build the model in inference mode (Keras learning phase 0) before
    # exporting its graph. Only the graph structure is needed here, so no
    # weights are loaded (weights=None).
    tf.keras.backend.set_learning_phase(0)
    model = tf.keras.applications.ResNet50(weights=None,
        include_top=True,
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000)
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())
    graph_def = K.get_session().graph.as_graph_def()
    # Keep only the subgraph needed to compute "probs/Softmax", the node that
    # is later passed to vai_p_tensorflow as --output_nodes.
    graph_def = graph_util.extract_sub_graph(graph_def, ["probs/Softmax"])
    # Write a text GraphDef (.pbtxt), the --input_graph format used below.
    tf.train.write_graph(graph_def,
        "./models/ResNet50/train",
        "ResNet50_inf_graph.pbtxt",
        as_text=True)
  3. Convert weights from HDF5 to TensorFlow format.
    Note: Skip this step if the weights are already in the TensorFlow format.
    import tensorflow as tf
    
    # Build ResNet50 with the Keras "imagenet" pre-trained weights and
    # re-save them in TensorFlow checkpoint format; this is the --input_ckpt
    # used by the analysis and pruning steps.
    tf.keras.backend.set_learning_phase(0)
    
    model = tf.keras.applications.ResNet50(weights="imagenet",
        include_top=True,
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000)
    model.save_weights("./models/ResNet50/train/ResNet50.ckpt", save_format='tf')
    
  4. Run model analysis.
    # --target must be a key of the dict returned by evaluate() in
    # ResNet50_model.py, which is 'Recall_5'.
    vai_p_tensorflow \
        --action=ana \
        --input_graph=./models/ResNet50/train/ResNet50_inf_graph.pbtxt \
        --input_ckpt=./models/ResNet50/train/ResNet50.ckpt \
        --eval_fn_path=./ResNet50_model.py \
        --target=Recall_5 \
        --workspace=./ \
        --input_nodes="input_1" \
        --input_node_shapes="1,224,224,3" \
        --exclude="" \
        --output_nodes="probs/Softmax"
    
  5. Run model pruning.
    # Prune the model; writes the pruned graph and the sparse checkpoint that
    # train.py fine-tunes from (its default --checkpoint_path).
    vai_p_tensorflow \
        --action=prune \
        --input_graph=./models/ResNet50/train/ResNet50_inf_graph.pbtxt \
        --input_ckpt=./models/ResNet50/train/ResNet50.ckpt \
        --output_graph=./models/ResNet50/pruned/graph.pbtxt \
        --output_ckpt=./models/ResNet50/pruned/sparse.ckpt \
        --workspace=./ \
        --input_nodes="input_1" \
        --input_node_shapes="1,224,224,3" \
        --exclude="" \
        --sparsity=0.5 \
        --output_nodes="probs/Softmax"
    
  6. Prepare model training code "train.py".
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import os, time
    import tensorflow as tf
    import numpy as np
    
    from preprocessing import preprocessing_factory
    from preprocessing.dataset import input_fn, NUM_IMAGES
    
    TRAIN_NUM = NUM_IMAGES['train']
    EVAL_NUM = NUM_IMAGES['validation']
    
    tf.flags.DEFINE_string('model_name', 'ResNet50', 'The keras model name.')
    tf.flags.DEFINE_boolean('pruning', True, 'If running with pruning masks.')
    tf.flags.DEFINE_string('data_dir', '', 'The directory where put the evaluation tfrecord data.')
    # Default matches the --output_ckpt of the pruning step. It previously had
    # a trailing space, which made the checkpoint path invalid.
    tf.flags.DEFINE_string('checkpoint_path', './models/ResNet50/pruned/sparse.ckpt', 'Model weights path from which to fine-tune.')
    tf.flags.DEFINE_string('train_dir', './models/ResNet50/pruned/ft', 'The directory where save model')
    tf.flags.DEFINE_string('ckpt_filename', "trained_model_{epoch}.ckpt", 'Model filename to be saved.')
    # main() saves the final fine-tuned weights to this path.
    tf.flags.DEFINE_string('ft_ckpt', '', 'The model path to be saved from last epoch.')
    
    tf.flags.DEFINE_integer('batch_size', 100, 'Train batch size.')
    tf.flags.DEFINE_integer('train_image_size', 224, 'Train image size.')
    tf.flags.DEFINE_integer('epoches', 1, 'Train epochs')
    tf.flags.DEFINE_integer('eval_every_epoch', 1, '')
    tf.flags.DEFINE_integer('steps_per_epoch', None, 'How many steps one epoch contains.')
    tf.flags.DEFINE_float('learning_rate', 5e-3, 'Learning rate.')
    
    FLAGS = tf.flags.FLAGS
    
    def get_input_data(num_epochs=1, prefix_preprocessing="vgg"):
        """Return the (train_data, eval_data) pipelines built by input_fn.

        Both read FLAGS.data_dir at FLAGS.train_image_size resolution with
        FLAGS.batch_size. num_epochs is passed through for the training
        pipeline; the evaluation pipeline always uses one epoch.
        """
        def _build(is_training, epochs):
            # input_fn arguments shared by the training and evaluation splits.
            return input_fn(
                is_training=is_training, data_dir=FLAGS.data_dir,
                output_width=FLAGS.train_image_size,
                output_height=FLAGS.train_image_size,
                batch_size=FLAGS.batch_size,
                num_epochs=epochs,
                num_gpus=1,
                dtype=tf.float32,
                prefix_preprocessing=prefix_preprocessing)

        train_data = _build(True, num_epochs)
        eval_data = _build(False, 1)
        return train_data, eval_data
    
    # Raise the verbosity first: at TensorFlow's default (WARN) level the INFO
    # message below would be dropped.
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info('Fine-tuning from %s' % FLAGS.checkpoint_path)
    if FLAGS.pruning:
        # Run with pruning masks (see the --pruning flag). Not a stock
        # TensorFlow API; presumably provided by the Vitis AI TensorFlow build.
        tf.set_pruning_mode()
    module_name = 'tf.keras.applications.' + FLAGS.model_name
    # Resolve the constructor with getattr instead of eval() so that the
    # --model_name value is never executed as code.
    model_builder = getattr(tf.keras.applications, FLAGS.model_name, None)
    if model_builder is None:
        raise ValueError('Unknown Keras application model: %s' % module_name)
    model = model_builder(weights=None,
        include_top=True,
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000)
    os.makedirs(FLAGS.train_dir, exist_ok=True)
    
    def main():
        """Fine-tune the module-level `model` and save its final weights.

        Loads FLAGS.checkpoint_path and trains for FLAGS.epoches epochs with
        Adam. ModelCheckpoint writes checkpoints into FLAGS.train_dir, and the
        final weights are saved to FLAGS.ft_ckpt in TensorFlow format.

        Raises:
          ValueError: if --ft_ckpt is empty. This is checked up front so that
            a missing flag does not surface only after training has finished.
        """
        if not FLAGS.ft_ckpt:
            raise ValueError('You must supply the path to save the fine-tuned model with --ft_ckpt')
        # NOTE(review): this ConfigProto is never passed to a session, so the
        # GPU options below currently have no effect. If they are needed,
        # apply them with tf.keras.backend.set_session(tf.Session(config=config))
        # before any Keras op runs.
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.gpu_options.allow_growth = True
        prefix_preprocessing = preprocessing_factory.get_preprocessing_method(FLAGS.model_name)
        # Only the training pipeline is used below.
        train_data, eval_data = get_input_data(num_epochs=FLAGS.epoches+1, prefix_preprocessing=prefix_preprocessing)
        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=os.path.join(FLAGS.train_dir, FLAGS.ckpt_filename),
                save_best_only=True,
                save_weights_only=True,
                monitor="sparse_categorical_accuracy",
                verbose=1,
            )
        ]
        opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        metric_top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy()
        accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        loss = tf.keras.losses.SparseCategoricalCrossentropy()
        model.compile(loss=loss, metrics=[accuracy, metric_top_5], optimizer=opt)
        model.load_weights(FLAGS.checkpoint_path)
    
        start = time.time()
        # Keras needs an integer step count; np.ceil returns a float.
        steps_per_epoch = FLAGS.steps_per_epoch or int(np.ceil(TRAIN_NUM / FLAGS.batch_size))
        model.fit(train_data,
            epochs=FLAGS.epoches,
            callbacks=callbacks,
            steps_per_epoch=steps_per_epoch,
            workers=16)
        t_delta = round(1000*(time.time()-start), 2)
        print("Training {} epoch needs {}ms".format(FLAGS.epoches, t_delta))
        model.save_weights(FLAGS.ft_ckpt, save_format='tf')
        print('Finished training!')
    
    if __name__ == "__main__":
        main()
    
  7. Run the training code to fine-tune the pruned model.
    # --ft_ckpt is where train.py saves the final weights; it is the input of
    # the transform step below.
    python train.py --pruning=True \
        --checkpoint_path=./models/ResNet50/pruned/sparse.ckpt \
        --data_dir=/scratch/workspace/dataset/imagenet/tf_records \
        --train_dir=./models/ResNet50/ft \
        --ft_ckpt=./models/ResNet50/ft/trained_model_epoch.ckpt
    
  8. Transform sparse model to dense model.
    vai_p_tensorflow \
        --action=transform \
        --input_ckpt=./models/ResNet50/ft/trained_model_epoch.ckpt \
        --output_ckpt=./models/ResNet50/pruned/transformed.ckpt
    
  9. Freeze graph.
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import argparse
    import sys
    
    from google.protobuf import text_format
    
    from tensorflow.core.framework import graph_pb2
    from tensorflow.core.protobuf import saver_pb2
    from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.client import session
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import importer
    from tensorflow.python.platform import app
    from tensorflow.python.platform import gfile
    from tensorflow.python.saved_model import loader
    from tensorflow.python.saved_model import tag_constants
    from tensorflow.python.tools import saved_model_utils
    from tensorflow.python.training import saver as saver_lib
    
    def freeze_graph_with_def_protos(input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph,
        clear_devices,
        initializer_nodes,
        variable_names_whitelist="",
        variable_names_blacklist="",
        input_meta_graph_def=None,
        input_saved_model_dir=None,
        saved_model_tags=None,
        checkpoint_version=saver_pb2.SaverDef.V2):
        """Converts all variables in a graph and checkpoint into constants.

        Variables are restored into a new session from the first available of
        input_saver_def, input_meta_graph_def and input_saved_model_dir.
        Otherwise, every checkpoint entry that names a tensor in the imported
        input_graph_def is restored. The variables are then folded into
        constants with graph_util.convert_variables_to_constants.

        Args:
          input_graph_def: GraphDef to import (may be None when a meta graph is
            used instead).
          input_saver_def: optional SaverDef used to restore the checkpoint.
          input_checkpoint: checkpoint path or prefix (Saver V2).
          output_node_names: comma-separated names of the output nodes to keep.
          restore_op_name: ignored (kept for API compatibility).
          filename_tensor_name: ignored (kept for API compatibility).
          output_graph: if non-empty, path the frozen GraphDef is written to in
            binary form.
          clear_devices: if True, strip device placements from the input meta
            graph (if given) or else the input graph.
          initializer_nodes: comma-separated ops run after restoring; only used
            on the meta-graph and default restore paths.
          variable_names_whitelist: comma-separated variable names to convert.
          variable_names_blacklist: comma-separated variable names not to convert.
          input_meta_graph_def: optional MetaGraphDef to restore from.
          input_saved_model_dir: optional SavedModel directory to load from.
          saved_model_tags: tags for input_saved_model_dir (None means []).
          checkpoint_version: SaverDef version used by the restoring Saver.

        Returns:
          The frozen GraphDef, or -1 if the checkpoint does not exist or no
          output node names were given.
        """
        del restore_op_name, filename_tensor_name # Unused by updated loading code.
    
        # 'input_checkpoint' may be a prefix if we're using Saver V2 format
        if (not input_saved_model_dir and
            not saver_lib.checkpoint_exists(input_checkpoint)):
            print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
            return -1
    
        if not output_node_names:
            print("You need to supply the name of a node to --output_node_names.")
            return -1
    
        # Remove all the explicit device specifications for this node. This helps to
        # make the graph more portable.
        if clear_devices:
            if input_meta_graph_def:
                for node in input_meta_graph_def.graph_def.node:
                    node.device = ""
            elif input_graph_def:
                for node in input_graph_def.node:
                    node.device = ""
        # Import into the default graph so the session below can see the nodes.
        if input_graph_def:
            _ = importer.import_graph_def(input_graph_def, name="")
        with session.Session() as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def,
                                        write_version=checkpoint_version)
                saver.restore(sess, input_checkpoint)
            elif input_meta_graph_def:
                restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                       clear_devices=True)
                restorer.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes.replace(" ", "").split(","))
            elif input_saved_model_dir:
                if saved_model_tags is None:
                    saved_model_tags = []
                loader.load(sess, saved_model_tags, input_saved_model_dir)
            else:
                # Default path: restore every checkpoint variable that has a
                # matching tensor ("<name>:0") in the imported graph.
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ":0")
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example it's
                        # 'global_step' or a similar housekeeping element) so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes.replace(" ", "").split(","))
    
            # Turn the comma-separated filter strings into lists (None = no filter).
            variable_names_whitelist = (variable_names_whitelist.replace(
    " ", "").split(",") if variable_names_whitelist else None)
            variable_names_blacklist = (variable_names_blacklist.replace(
    " ", "").split(",") if variable_names_blacklist else None)
    
            if input_meta_graph_def:
                output_graph_def = graph_util.convert_variables_to_constants(
                    sess,
                    input_meta_graph_def.graph_def,
                    output_node_names.replace(" ", "").split(","),
                    variable_names_whitelist=variable_names_whitelist,
                    variable_names_blacklist=variable_names_blacklist)
            else:
                output_graph_def = graph_util.convert_variables_to_constants(
                    sess,
                    input_graph_def,
                    output_node_names.replace(" ", "").split(","),
                    variable_names_whitelist=variable_names_whitelist,
                    variable_names_blacklist=variable_names_blacklist)
    
        # Write GraphDef to file if output path has been given.
        if output_graph:
            with gfile.GFile(output_graph, "wb") as f:
                f.write(output_graph_def.SerializeToString())
    
        return output_graph_def
    
    def _parse_input_graph_proto(input_graph, input_binary):
        """Read a GraphDef from input_graph (binary if input_binary, else text).

        Returns:
          The parsed GraphDef, or -1 (after printing a message) when the file
          does not exist.
        """
        if not gfile.Exists(input_graph):
            print("Input graph file '" + input_graph + "' does not exist!")
            return -1
        graph_def = graph_pb2.GraphDef()
        with gfile.FastGFile(input_graph, "rb" if input_binary else "r") as handle:
            contents = handle.read()
        if input_binary:
            graph_def.ParseFromString(contents)
        else:
            text_format.Merge(contents, graph_def)
        return graph_def
    
    def _parse_input_meta_graph_proto(input_graph, input_binary):
        """Read a MetaGraphDef from input_graph (binary if input_binary, else text).

        Returns:
          The parsed MetaGraphDef, or -1 (after printing a message) when the
          file does not exist.
        """
        if not gfile.Exists(input_graph):
            print("Input meta graph file '" + input_graph + "' does not exist!")
            return -1
        meta_graph_def = MetaGraphDef()
        with gfile.FastGFile(input_graph, "rb" if input_binary else "r") as handle:
            contents = handle.read()
        if input_binary:
            meta_graph_def.ParseFromString(contents)
        else:
            text_format.Merge(contents, meta_graph_def)
        print("Loaded meta graph file '" + input_graph)
        return meta_graph_def
    
    def _parse_input_saver_proto(input_saver, input_binary):
        """Read a SaverDef from input_saver (binary if input_binary, else text).

        Returns:
          The parsed SaverDef, or -1 (after printing a message) when the file
          does not exist.
        """
        if not gfile.Exists(input_saver):
            print("Input saver file '" + input_saver + "' does not exist!")
            return -1
        with gfile.FastGFile(input_saver, "rb" if input_binary else "r") as handle:
            contents = handle.read()
        saver_def = saver_pb2.SaverDef()
        if input_binary:
            saver_def.ParseFromString(contents)
        else:
            text_format.Merge(contents, saver_def)
        return saver_def
    
    def freeze_graph(input_graph,
                     input_saver,
                     input_binary,
                     input_checkpoint,
                     output_node_names,
                     restore_op_name,
                     filename_tensor_name,
                     output_graph,
                     clear_devices,
                     initializer_nodes,
                     variable_names_whitelist="",
                     variable_names_blacklist="",
                     input_meta_graph=None,
                     input_saved_model_dir=None,
                     saved_model_tags=tag_constants.SERVING,
                     checkpoint_version=saver_pb2.SaverDef.V2):
        """Converts all variables in a graph and checkpoint into constants.

        File-based front end for freeze_graph_with_def_protos. The graph comes
        from input_saved_model_dir when set, otherwise from the input_graph
        file. The optional meta graph and saver files are parsed with the same
        input_binary setting. saved_model_tags is a comma-separated string.
        The result of freeze_graph_with_def_protos is not returned.
        """
        if input_saved_model_dir:
            graph_def = saved_model_utils.get_meta_graph_def(
                input_saved_model_dir, saved_model_tags).graph_def
        elif input_graph:
            graph_def = _parse_input_graph_proto(input_graph, input_binary)
        else:
            graph_def = None

        meta_graph_def = (
            _parse_input_meta_graph_proto(input_meta_graph, input_binary)
            if input_meta_graph else None)
        saver_def = (
            _parse_input_saver_proto(input_saver, input_binary)
            if input_saver else None)
        tag_list = saved_model_tags.replace(" ", "").split(",")

        freeze_graph_with_def_protos(graph_def,
                                     saver_def,
                                     input_checkpoint,
                                     output_node_names,
                                     restore_op_name,
                                     filename_tensor_name,
                                     output_graph,
                                     clear_devices,
                                     initializer_nodes,
                                     variable_names_whitelist,
                                     variable_names_blacklist,
                                     meta_graph_def,
                                     input_saved_model_dir,
                                     tag_list,
                                     checkpoint_version=checkpoint_version)
    
    def main(unused_args, flags):
        """Validate the checkpoint-version flag and run ``freeze_graph``.

        Args:
            unused_args: Positional arguments supplied by ``app.run``; ignored.
            flags: Parsed argparse namespace produced by ``run_main``.

        Returns:
            -1 if ``flags.checkpoint_version`` is neither 1 nor 2. Otherwise
            returns ``freeze_graph``'s result, so a failure is not silently
            dropped. Presumably ``app.run`` uses this as the process exit
            status — confirm.
        """
        if flags.checkpoint_version == 1:
            checkpoint_version = saver_pb2.SaverDef.V1
        elif flags.checkpoint_version == 2:
            checkpoint_version = saver_pb2.SaverDef.V2
        else:
            print("Invalid checkpoint version (must be '1' or '2'): %d" %
                  flags.checkpoint_version)
            return -1
        return freeze_graph(flags.input_graph, flags.input_saver,
                            flags.input_binary, flags.input_checkpoint,
                            flags.output_node_names, flags.restore_op_name,
                            flags.filename_tensor_name, flags.output_graph,
                            flags.clear_devices, flags.initializer_nodes,
                            flags.variable_names_whitelist,
                            flags.variable_names_blacklist,
                            flags.input_meta_graph, flags.input_saved_model_dir,
                            flags.saved_model_tags, checkpoint_version)
    
    def run_main():
        """Parse command-line flags and run ``main`` through ``app.run``.

        Recognized flags are parsed with argparse. Unrecognized arguments are
        forwarded to ``app.run`` after the program name.
        """
        parser = argparse.ArgumentParser()
        # Custom "bool" type: only the case-insensitive string "true" is True.
        parser.register("type", "bool", lambda v: v.lower() == "true")
        parser.add_argument("--input_graph",
                            type=str,
                            default="./models/ResNet50/pruned/graph.pbtxt",
                            help="TensorFlow \'GraphDef\' file to load.")
        parser.add_argument("--input_saver",
                            type=str,
                            default="",
                            help="TensorFlow saver file to load.")
        parser.add_argument("--input_checkpoint",
                            type=str,
                            default="./models/ResNet50/pruned/transformed.ckpt",
                            help="TensorFlow variables file to load.")
        parser.add_argument("--checkpoint_version",
                            type=int,
                            default=2,
                            help="Tensorflow variable file format")
        parser.add_argument("--output_graph",
                            type=str,
                            default="./models/ResNet50/pruned/frozen_ResNet50.pb",
                            help="Output \'GraphDef\' file name.")
        # nargs="?" with const=True: a bare "--input_binary" means True, while
        # "--input_binary=false" goes through the "bool" type. argparse only
        # accepts `const` for a store action when nargs is "?".
        parser.add_argument("--input_binary",
                            nargs="?",
                            const=True,
                            type="bool",
                            default=False,
                            help="Whether the input files are in binary format.")
        parser.add_argument("--output_node_names",
                            type=str,
                            default="probs/Softmax",
                            help="The name of the output nodes, comma separated.")
        parser.add_argument("--restore_op_name",
                            type=str,
                            default="save/restore_all",
                            help="""\
                The name of the master restore operator. Deprecated, unused by updated \
                loading code.
                """)
        parser.add_argument("--filename_tensor_name",
                            type=str,
                            default="save/Const:0",
                            help="""\
                The name of the tensor holding the save path. Deprecated, unused by \
                updated loading code.
                """)
        # Same optional-value boolean pattern as --input_binary.
        parser.add_argument("--clear_devices",
                            nargs="?",
                            const=True,
                            type="bool",
                            default=True,
                            help="Whether to remove device specifications.")
        parser.add_argument(
                            "--initializer_nodes",
                            type=str,
                            default="",
                            help="Comma separated list of initializer nodes to run before freezing.")
        parser.add_argument("--variable_names_whitelist",
                            type=str,
                            default="",
                            help="""\
                Comma separated list of variables to convert to constants. If specified, \
                only those variables will be converted to constants.\
                """)
        parser.add_argument("--variable_names_blacklist",
                            type=str,
                            default="",
                            help="""\
                Comma separated list of variables to skip converting to constants.\
                """)
        parser.add_argument("--input_meta_graph",
                            type=str,
                            default="",
                            help="TensorFlow \'MetaGraphDef\' file to load.")
        parser.add_argument(
                            "--input_saved_model_dir",
                            type=str,
                            default="",
                            help="Path to the dir with TensorFlow \'SavedModel\' file and variables.")
        parser.add_argument("--saved_model_tags",
                            type=str,
                            default="serve",
                            help="""\
                Group of tag(s) of the MetaGraphDef to load, in string format,\
                separated by \',\'. For tag-set contains multiple tags, all tags \
                must be passed in.\
                """)
        flags, unparsed = parser.parse_known_args()

        # Bind the parsed flags so app.run can invoke main with only argv.
        my_main = lambda unused_args: main(unused_args, flags)
        app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
    
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    if __name__ == '__main__':
        run_main()