# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CIFAR-10 example using Keras layers for model definition."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

tf.flags.DEFINE_integer("batch_size", 128,
                        "Mini-batch size for the computation. Note that this "
                        "is the global batch size and not the per-shard batch.")
tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.")
tf.flags.DEFINE_string("train_file", "", "Path to cifar10 training data.")
tf.flags.DEFINE_integer("train_steps", 100000,
                        "Total number of steps. Note that the actual number of "
                        "steps is the next multiple of --iterations greater "
                        "than this value.")
tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs.")
tf.flags.DEFINE_string("master", "",
                       "BNS name of the TensorFlow master to use.")
tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir.")
tf.flags.DEFINE_integer("iterations", 100,
                        "Number of iterations per TPU training loop.")
tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).")


FLAGS = tf.flags.FLAGS


def model_fn(features, labels, mode, params):
  """Define a CIFAR model in Keras."""
  del params  # unused
  layers = tf.contrib.keras.layers

  # Pass our input tensor to initialize the Keras input layer.
  v = layers.Input(tensor=features)
  v = layers.Conv2D(filters=32, kernel_size=5,
                    activation="relu", padding="same")(v)
  v = layers.MaxPool2D(pool_size=2)(v)
  v = layers.Conv2D(filters=64, kernel_size=5,
                    activation="relu", padding="same")(v)
  v = layers.MaxPool2D(pool_size=2)(v)
  v = layers.Flatten()(v)
  fc1 = layers.Dense(units=512, activation="relu")(v)
  logits = layers.Dense(units=10)(fc1)
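
  # Shapes through the network: 32x32x3 input -> 32x32x32 -> 16x16x32
  # -> 16x16x64 -> 8x8x64 -> 4096 after Flatten -> 512 -> 10 logits.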

  # Instead of constructing a Keras model for training, build our loss function
  # and optimizer in TensorFlow.
  #
  # N.B. This construction omits some features that are important for more
  # complex models (e.g. regularization, batch normalization). Once
  # `model_to_estimator` support is added for TPUs, it should be used instead.
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=labels))
  optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
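  # Each TPU shard computes gradients on its own slice of the global batch;
  # CrossShardOptimizer averages those gradients across shards with an
  # all-reduce before applying the update.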
  if FLAGS.use_tpu:
    optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      predictions={
          "classes": tf.argmax(input=logits, axis=1),
          "probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
      })


def input_fn(params):
  """Read CIFAR input data from a TFRecord dataset."""
  del params  # unused
  batch_size = FLAGS.batch_size

  def parser(serialized_example):
    """Parses a single tf.Example into image and label tensors."""
    features = tf.parse_single_example(
        serialized_example,
        features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features["image"], tf.uint8)
    image.set_shape([3 * 32 * 32])
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    # The record stores the image as CHW; transpose to HWC (channels_last),
    # the layout the Keras Conv2D layers above expect.
    image = tf.transpose(tf.reshape(image, [3, 32, 32]), [1, 2, 0])
    label = tf.cast(features["label"], tf.int32)
    return image, label

  dataset = tf.data.TFRecordDataset([FLAGS.train_file])
  dataset = dataset.map(parser, num_parallel_calls=batch_size)
  dataset = dataset.prefetch(4 * batch_size).cache().repeat()
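  # TPUs require statically known shapes, so batch with drop-remainder
  # semantics: any final partial batch is discarded rather than emitted.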
  dataset = dataset.apply(
      tf.contrib.data.batch_and_drop_remainder(batch_size))
  dataset = dataset.prefetch(1)
  images, labels = dataset.make_one_shot_iterator().get_next()
  return images, labels


def main(argv):
  del argv  # Unused.

  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      save_checkpoints_secs=3600,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations, num_shards=FLAGS.num_shards),
  )

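  # train_batch_size is the *global* batch size; TPUEstimator splits it
  # evenly across num_shards, so each shard processes
  # batch_size / num_shards examples per step.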
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      train_batch_size=FLAGS.batch_size)
  estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)