1

我在文本样本上训练 CNN 模型时遇到了问题。

  1. 在单 GPU 上训练时,我的 GPU 利用率非常高,大约 97%,但训练速度非常慢。1000 个批次需要 450 秒(每批次 64 个示例),因此每个示例需要 7 毫秒。相比之下,分层 lstm 每个示例只需要 2~3ms。
  2. 我尝试把训练过程部署到 GPU 集群上,但 GPU 利用率很奇怪:我使用了 4 个 GPU,大部分时间利用率为 0%。我尝试将批量大小从 64 改为 2,GPU 利用率随之恢复正常,但这么小的批量会导致训练效率低下。所以我想问,有没有一种有效的方法可以利用 GPU 集群来加快训练过程。

(顺便说一句,这些问题仅在单个输入示例非常大时发生,例如包含数千个单词的新闻正文内容。当输入是新闻标题时,GPU 集群工作正常)

输入格式 [ 64 (examples/batch) * 2500 (words/example) * 200 (embedding dim) ] 对于我的 5 层 cnn 模型来说是否太大而无法正确训练?

1. 模型定义(改编自 https://github.com/dennybritz/cnn-text-classification-tf)

import tensorflow as tf
import numpy as np

class TextCNN(object):
    """
    A CNN for text classification.

    Uses an embedding layer followed by one three-stage convolution/max-pool
    tower per entry of ``filter_sizes``. Each tower ends in a sigmoid
    fully-connected layer; the tower outputs are concatenated and passed
    through a second sigmoid fully-connected layer, dropout and a linear
    output layer.

    Tensors exposed on the instance:
      input_x, input_y, dropout_keep_prob : placeholders to feed.
      scores, predictions                 : unnormalized logits and argmax class.
      loss                                : mean cross-entropy plus L2 penalty.
      correct_predition, correct_num, accuracy : evaluation helpers.
    """

    # Example configuration used by the accompanying training scripts:
    # sequence_length : 2500 (words per doc)
    # num_classes : 36
    # vocab_size : 500,000
    # embedding size : 200
    # filter_sizes : [25, 50, 100]
    # num_filters : [32, 64, 128]
    def __init__(
      self, sequence_length, num_classes, vocab_size,
      embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):
        """
        Build the whole model graph in the current default graph.

        Args:
          sequence_length: number of token ids per example (width of input_x).
          num_classes: number of output classes (width of input_y).
          vocab_size: number of rows of the embedding matrix.
          embedding_size: number of columns of the embedding matrix.
          filter_sizes: height of the first convolution kernel, one tower
            is built per entry.
          num_filters: three integers, the output channels of the first,
            second and third convolution of every tower.
          l2_reg_lambda: weight of the L2 penalty applied to the output
            layer's weights and bias.
        """

        # Placeholders for input, output and dropout
        self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y")
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")

        # Keeping track of l2 regularization loss (optional)
        l2_loss = tf.constant(0.0)

        # Embedding layer: [batch, sequence_length] ids ->
        # [batch, sequence_length, embedding_size, 1] so conv2d can treat the
        # embedded document as a single-channel image.
        self.W = tf.Variable(
            tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
            name="W")
        self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
        self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

        def conv2d(x, W, stride_h, stride_w):
            """Unpadded 2-D convolution with the given height/width strides."""
            return tf.nn.conv2d(x, W, strides=[1, stride_h, stride_w, 1], padding="VALID")

        def max_pool(x, h, w):
            """Non-overlapping h x w max pooling (window equals stride)."""
            return tf.nn.max_pool(x, ksize=[1, h, w, 1], strides=[1, h, w, 1], padding="VALID")

        def weight_variable(shape):
            """Trainable weights drawn from a truncated normal (stddev 0.1)."""
            initial = tf.truncated_normal(shape, stddev=0.1)
            return tf.Variable(initial)

        def bias_variable(shape):
            """Trainable bias initialised to the constant 0.1."""
            initial = tf.constant(0.1, shape=shape)
            return tf.Variable(initial)

        n_conv1 = num_filters[0]    # 32
        n_conv2 = num_filters[1]    # 64
        n_conv3 = num_filters[2]    # 128
        n_fc1 = 200
        n_fc2 = 200

        # One convolution tower per first-layer kernel height; each tower
        # contributes an n_fc1-wide feature vector.
        tower_outputs = []

        # filter_sizes : [25, 50, 100]
        for filter_size in filter_sizes:
            with tf.name_scope("conv-maxpool-%s" % filter_size):
                # The banner is formatted explicitly; passing the value as a
                # second print item would leave the "%s" unsubstituted.
                print("######## conv-maxpool-%s ########" % filter_size)

                # Stage 1: kernel spans filter_size words x 40 embedding
                # columns and moves 25 words / 5 columns at a time.
                w_conv1 = weight_variable([filter_size, 40, 1, n_conv1])
                b_conv1 = bias_variable([n_conv1])
                f_conv1 = tf.nn.relu(conv2d(self.embedded_chars_expanded, w_conv1, 25, 5) + b_conv1)
                print("conv1: %s" % (f_conv1,))
                f_pool1 = max_pool(f_conv1, 2, 2)
                print("pool1: %s" % (f_pool1,))

                # Stage 2
                w_conv2 = weight_variable([3, 3, n_conv1, n_conv2])
                b_conv2 = bias_variable([n_conv2])
                f_conv2 = tf.nn.relu(conv2d(f_pool1, w_conv2, 2, 1) + b_conv2)
                print("conv2: %s" % (f_conv2,))
                f_pool2 = max_pool(f_conv2, 2, 2)
                print("pool2: %s" % (f_pool2,))

                # Stage 3
                w_conv3 = weight_variable([2, 2, n_conv2, n_conv3])
                b_conv3 = bias_variable([n_conv3])
                f_conv3 = tf.nn.relu(conv2d(f_pool2, w_conv3, 1, 1) + b_conv3)
                print("conv3: %s" % (f_conv3,))
                f_pool3 = max_pool(f_conv3, 2, 2)
                print("pool3: %s" % (f_pool3,))

                # Derive the flattened size from the static shape instead of
                # hard-coding 5 * 3 (its value for the example configuration
                # above). A wrong constant would not fail: reshape(-1, ...)
                # would silently mix values of different examples.
                pool3_shape = f_pool3.get_shape().as_list()
                f_size_conv3 = pool3_shape[1] * pool3_shape[2]
                f_pool3_flat = tf.reshape(f_pool3, [-1, f_size_conv3 * n_conv3])

                w_fc1 = weight_variable([f_size_conv3 * n_conv3, n_fc1])
                b_fc1 = bias_variable([n_fc1])
                f_fc1 = tf.nn.sigmoid(tf.matmul(f_pool3_flat, w_fc1) + b_fc1)
                print("f_fc1: %s" % (f_fc1,))
                tower_outputs.append(f_fc1)

        # Combine the features of all towers
        i_fc2 = tf.concat(tower_outputs, 1)
        print(i_fc2)
        w_fc2 = weight_variable([n_fc1 * len(filter_sizes), n_fc2])
        b_fc2 = bias_variable([n_fc2])
        f_fc2 = tf.nn.sigmoid(tf.matmul(i_fc2, w_fc2) + b_fc2)
        print("f_fc2: %s" % (f_fc2,))
        num_filters_total = n_fc2
        self.h_pool_flat = tf.reshape(f_fc2, [-1, num_filters_total])

        # Add dropout
        with tf.name_scope("dropout"):
            self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)

        # Final (unnormalized) scores and predictions
        with tf.name_scope("output"):
            W = tf.get_variable(
                "W",
                shape=[num_filters_total, num_classes],
                initializer=tf.contrib.layers.xavier_initializer())
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
            l2_loss += tf.nn.l2_loss(W)
            l2_loss += tf.nn.l2_loss(b)
            self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")
            self.predictions = tf.argmax(self.scores, 1, name="predictions")

        # Mean cross-entropy loss. The logits are passed unmodified: adding a
        # constant to every logit does not change the softmax, so the former
        # "+1e-10" had no effect.
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

        # Accuracy
        with tf.name_scope("accuracy"):
            # Attribute name (including its spelling) is kept for callers.
            self.correct_predition = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
            self.correct_num = tf.reduce_sum(tf.cast(self.correct_predition, tf.float32))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_predition, "float"), name="accuracy")

2. 训练过程(单 GPU)

import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_loader_cnn as data_loader
from tensorflow.contrib import learn

import sys
sys.path.append('./model_def')
from cnn_model import TextCNN

# Command-line flags for the single-GPU training script.
# The help texts quote the defaults actually defined here.

# Data loading params
tf.flags.DEFINE_string("train_path", "/data/train_data.idx", "Data source for the training data.")
tf.flags.DEFINE_string("valid_path", "/data/valid_data.idx", "Data source for the validation data.")
tf.flags.DEFINE_string("ckpt_dir", "runs-cnn", "Directory for checkpoints.")
tf.flags.DEFINE_integer("class_num", 36, "Number of total classes")
tf.flags.DEFINE_integer("vocab_size", 500000, "Number of total distinct words")
tf.flags.DEFINE_integer("document_length", 50, "Max number of sentences in single text")
tf.flags.DEFINE_integer("sentence_length", 50, "Max number of words in single sentence")

# Model Hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 200, "Dimensionality of word embedding (default: 200)")
tf.flags.DEFINE_string("filter_sizes", "25,50,100", "Comma-separated filter sizes (default: '25,50,100')")
tf.flags.DEFINE_string("num_filters", "32,64,128", "Comma-separated number of filters of conv layers 1-3 (default: '32,64,128')")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")
tf.flags.DEFINE_float("lr", 0.1, "Learning rate (default: 0.1)")
tf.flags.DEFINE_float("lr_decay", 0.5, "Learning rate decay per epoch (default: 0.5)")
tf.flags.DEFINE_integer("max_decay_epoch", 10, "Max epoch before decay lr (default: 10)")
tf.flags.DEFINE_integer('max_grad_norm', 5, 'Global norm to which gradients are clipped (default: 5)')

# Training parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_integer("num_epochs", 60, "Number of training epochs (default: 60)")
tf.flags.DEFINE_integer("evaluate_every", 1000, "Evaluate model on dev set after this many steps (default: 1000)")
tf.flags.DEFINE_integer("checkpoint_every", 50000, "Save model after this many steps (default: 50000)")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_growth", True, "Allow GPU memory to grow on demand")
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

# For distributed (defined here as well so both scripts accept the same flags)
tf.flags.DEFINE_string("ps_hosts", "",
                       "Comma-separated list of hostname:port pairs")
tf.flags.DEFINE_string("worker_hosts", "",
                       "Comma-separated list of hostname:port pairs")
tf.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.flags.DEFINE_integer("issync", 0, "1 for sync and 0 for async")

# Parse the flags eagerly and echo every value so each run logs its settings.
# NOTE(review): _parse_flags() and __flags are private TensorFlow attributes;
# they are kept because the script targets the TensorFlow version that has them.
FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")


# Training
# ==================================================
def main(_):
    """
    Train the TextCNN on a single GPU, evaluating and checkpointing
    periodically.

    The positional argument is ignored; all settings come from FLAGS.

    Raises:
      ValueError: if the training set holds fewer examples than one batch.
    """
    # Load data
    print("Loading data...")
    document_length = FLAGS.document_length
    sentence_length = FLAGS.sentence_length
    train_data = data_loader.load_data(FLAGS.train_path, document_length, sentence_length, FLAGS.class_num)
    valid_data = data_loader.load_data(FLAGS.valid_path, document_length, sentence_length, FLAGS.class_num)
    # Floor division keeps this an integer on Python 2 and Python 3 alike.
    batch_num_per_epoch = len(train_data[0]) // FLAGS.batch_size
    if batch_num_per_epoch == 0:
        # Otherwise the epoch computation below divides by zero.
        raise ValueError("Training data holds fewer examples than one batch of %d"
                         % FLAGS.batch_size)

    print("len train data %d" % len(train_data[0]))
    print("batch_num_per_epoch %d" % batch_num_per_epoch)

    gpu_config = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    gpu_config.gpu_options.allow_growth = FLAGS.allow_growth

    # tf.device() attaches to the graph that is the default at the moment it
    # is entered, so it has to be opened *after* the new graph became the
    # default; entering it first would leave the model without the device.
    with tf.Graph().as_default(), tf.device('/gpu:1'), \
            tf.Session(config=gpu_config) as sess:
        # Building model
        cnn = TextCNN(
            sequence_length=document_length*sentence_length,
            num_classes=FLAGS.class_num,
            vocab_size=FLAGS.vocab_size,
            embedding_size=FLAGS.embedding_dim,
            filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
            num_filters=list(map(int, FLAGS.num_filters.split(","))),
            l2_reg_lambda=FLAGS.l2_reg_lambda)

        # Define Training procedure
        global_step = tf.Variable(0, name="global_step", trainable=False)
        lr = tf.Variable(0.0, trainable=False)
        new_lr = tf.placeholder(tf.float32, shape=[], name="new_learning_rate")
        _lr_update = tf.assign(lr, new_lr)

        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cnn.loss, tvars),
                                          FLAGS.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(lr)
        # Materialise the pairs and build exactly one update op; a second
        # apply_gradients call only adds a dead op to the graph.
        grads_and_vars = list(zip(grads, tvars))
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Output directory for models and summaries
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, FLAGS.ckpt_dir, timestamp))
        print("Writing to {}\n".format(out_dir))

        def scheduled_lr(epoch_index):
            """Hard-coded step schedule of the learning rate per epoch."""
            if epoch_index < 15:
                return 0.1
            elif epoch_index < 25:
                return 0.01
            elif epoch_index < 40:
                return 0.001
            return 0.0001

        def train_step(sess, x_batch, y_batch, epoch_index):
            """
            A single training step; returns (loss, ISO timestamp).
            """
            # Assign the learning rate in its own run call: inside one
            # sess.run the order of the assign and the update is undefined.
            sess.run(_lr_update, {new_lr: scheduled_lr(epoch_index)})
            feed_dict = {
              cnn.input_x: x_batch,
              cnn.input_y: y_batch,
              cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
            }
            _, loss = sess.run([train_op, cnn.loss], feed_dict)
            time_str = datetime.datetime.now().isoformat()
            return loss, time_str

        # ====================== dev_step ======================
        def dev_step(sess, x_batch, y_batch, writer=None):
            """
            Evaluates model on one validation batch; returns
            (batch size, number of correct predictions).

            Evaluation is read-only: it no longer overwrites the learning
            rate variable used for training. `writer` is unused.
            """
            feed_dict = {
              cnn.input_x: x_batch,
              cnn.input_y: y_batch,
              cnn.dropout_keep_prob: 1.0
            }
            correct_num = sess.run(cnn.correct_num, feed_dict)
            return len(x_batch), correct_num

        # ====================== eval ======================
        def evaluate(sess, valid_data, batch_size):
            """Accuracy over the whole validation set (0.0 if it is empty)."""
            example_num = 0
            correct_num = 0
            for valid_x, valid_y in data_loader.batch_iter(valid_data, batch_size):
                batch_len, batch_correct = dev_step(sess, valid_x, valid_y)
                example_num += batch_len
                correct_num += batch_correct
            if example_num == 0:
                return 0.0
            return float(correct_num) / example_num

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables())
        sess.run(init_op)

        # Generate batches
        batch_iter = data_loader.global_batch_iter(
            train_data, FLAGS.batch_size, FLAGS.num_epochs)
        # Training loop. For each batch...
        total_steps = batch_num_per_epoch * FLAGS.num_epochs
        current_step = sess.run(global_step)
        print("current step %d" % current_step)
        while current_step < total_steps:
            epoch_index = current_step // batch_num_per_epoch
            if current_step % batch_num_per_epoch == 0:
                print("Epoch %d" % epoch_index)

            x_batch, y_batch = next(batch_iter)
            loss, time_str = train_step(sess, x_batch, y_batch, epoch_index)

            if current_step % FLAGS.evaluate_every == 0:
                accuracy = evaluate(sess, valid_data, FLAGS.batch_size)
                print("{}: step {}, loss {:g}, acc {:g}".format(time_str, current_step, loss, accuracy))

            if current_step % FLAGS.checkpoint_every == 0:
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                print("Saved model checkpoint to {}\n".format(path))

            # Refresh the step *after* training so the loop condition sees
            # the up-to-date value; testing the stale value ran one step
            # too many and could exhaust the batch iterator.
            current_step = sess.run(global_step)


# Script entry point: runs only when this file is executed directly, not when
# it is imported. Control is handed to tf.app.run(), which is expected to
# invoke main() defined above.
if __name__ == "__main__":
  tf.app.run()

3. 并行训练过程

import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_loader_cnn as data_loader
from cnn_model import TextCNN
from tensorflow.contrib import learn

# Command-line flags for the distributed training script.
# The help texts quote the defaults actually defined here.

# Data loading params
tf.flags.DEFINE_string("train_path", "/data/slice/", "Path prefix of the training data shards; the worker's task index is appended.")
tf.flags.DEFINE_string("valid_path", "/data/valid_data.idx", "Data source for the validation data.")
tf.flags.DEFINE_string("ckpt_dir", "runs-cnn", "Directory for checkpoints.")
tf.flags.DEFINE_integer("class_num", 36, "Number of total classes")
tf.flags.DEFINE_integer("vocab_size", 500000, "Number of total distinct words")
tf.flags.DEFINE_integer("document_length", 50, "Max number of sentences in single text")
tf.flags.DEFINE_integer("sentence_length", 50, "Max number of words in single sentence")

# Model Hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 200, "Dimensionality of word embedding (default: 200)")
tf.flags.DEFINE_string("filter_sizes", "25,50,100", "Comma-separated filter sizes (default: '25,50,100')")
tf.flags.DEFINE_string("num_filters", "32,64,128", "Comma-separated number of filters of conv layers 1-3 (default: '32,64,128')")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")
tf.flags.DEFINE_float("lr", 0.1, "Learning rate (default: 0.1)")
tf.flags.DEFINE_float("lr_decay", 0.5, "Learning rate decay per epoch (default: 0.5)")
tf.flags.DEFINE_integer("max_decay_epoch", 10, "Max epoch before decay lr (default: 10)")
tf.flags.DEFINE_integer('max_grad_norm', 5, 'Global norm to which gradients are clipped (default: 5)')

# Training parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_integer("num_epochs", 60, "Number of training epochs (default: 60)")
tf.flags.DEFINE_integer("evaluate_every", 1000, "Evaluate model on dev set after this many steps (default: 1000)")
tf.flags.DEFINE_integer("checkpoint_every", 50000, "Save model after this many steps (default: 50000)")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

# For distributed
tf.flags.DEFINE_string("ps_hosts", "",
                       "Comma-separated list of hostname:port pairs")
tf.flags.DEFINE_string("worker_hosts", "",
                       "Comma-separated list of hostname:port pairs")
tf.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.flags.DEFINE_integer("issync", 0, "1 for sync and 0 for async")

# Parse the flags eagerly and echo every value so each run logs its settings.
# NOTE(review): _parse_flags() and __flags are private TensorFlow attributes;
# they are kept because the script targets the TensorFlow version that has them.
FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")


# Training
# ==================================================
def main(_):
    """
    Run one task of the distributed (parameter server / worker) training job.

    A "ps" task only serves variables and blocks forever. A "worker" task
    loads its own training shard (FLAGS.train_path + task index), builds the
    model with variables placed by replica_device_setter, and trains against
    the shared global step. The positional argument is ignored.

    NOTE: FLAGS.issync is currently not used; updates are applied by every
    worker independently.

    Raises:
      ValueError: if the worker's shard holds fewer examples than one batch.
    """
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    server = tf.train.Server(cluster,job_name=FLAGS.job_name,task_index=FLAGS.task_index)
    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        # Load data
        print("Loading data...")
        document_length = FLAGS.document_length
        sentence_length = FLAGS.sentence_length
        train_path = FLAGS.train_path + str(FLAGS.task_index)
        train_data = data_loader.load_data(train_path, document_length, sentence_length, FLAGS.class_num)
        valid_data = data_loader.load_data(FLAGS.valid_path, document_length, sentence_length, FLAGS.class_num)
        # Floor division keeps this an integer on Python 2 and Python 3 alike.
        batch_num_per_epoch = len(train_data[0]) // FLAGS.batch_size
        if batch_num_per_epoch == 0:
            # Otherwise the epoch computation below divides by zero.
            raise ValueError("Training shard %s holds fewer examples than one batch of %d"
                             % (train_path, FLAGS.batch_size))

        print("len train data %d" % len(train_data[0]))
        print("batch_num_per_epoch %d" % batch_num_per_epoch)

        with tf.device(tf.train.replica_device_setter(
                          worker_device="/job:worker/task:%d" % FLAGS.task_index,
                          cluster=cluster)):
            # Building model
            cnn = TextCNN(
                sequence_length=document_length*sentence_length,
                num_classes=FLAGS.class_num,
                vocab_size=FLAGS.vocab_size,
                embedding_size=FLAGS.embedding_dim,
                filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                num_filters=list(map(int, FLAGS.num_filters.split(","))),
                l2_reg_lambda=FLAGS.l2_reg_lambda)

            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            lr = tf.Variable(0.0, trainable=False)
            new_lr = tf.placeholder(tf.float32, shape=[], name="new_learning_rate")
            _lr_update = tf.assign(lr, new_lr)

            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(cnn.loss, tvars),
                                          FLAGS.max_grad_norm)
            optimizer = tf.train.GradientDescentOptimizer(lr)
            # Materialise the pairs and build exactly one update op; a second
            # apply_gradients call only adds a dead op to the graph.
            grads_and_vars = list(zip(grads, tvars))
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, FLAGS.ckpt_dir, timestamp))
            print("Writing to {}\n".format(out_dir))

            def scheduled_lr(epoch_index):
                """Hard-coded step schedule of the learning rate per epoch."""
                if epoch_index < 15:
                    return 0.1
                elif epoch_index < 25:
                    return 0.01
                elif epoch_index < 40:
                    return 0.001
                return 0.0001

            def train_step(sess, x_batch, y_batch, epoch_index):
                """
                A single training step; returns (loss, ISO timestamp).
                """
                # Assign the learning rate in its own run call: inside one
                # sess.run the order of the assign and the update is undefined.
                sess.run(_lr_update, {new_lr: scheduled_lr(epoch_index)})
                feed_dict = {
                  cnn.input_x: x_batch,
                  cnn.input_y: y_batch,
                  cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                _, loss = sess.run([train_op, cnn.loss], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                return loss, time_str

            # ====================== dev_step ======================
            def dev_step(sess, x_batch, y_batch, writer=None):
                """
                Evaluates model on one validation batch; returns
                (batch size, number of correct predictions).

                Evaluation is read-only: it no longer overwrites the shared
                learning rate variable. `writer` is unused.
                """
                feed_dict = {
                  cnn.input_x: x_batch,
                  cnn.input_y: y_batch,
                  cnn.dropout_keep_prob: 1.0
                }
                correct_num = sess.run(cnn.correct_num, feed_dict)
                return len(x_batch), correct_num

            # ====================== eval ======================
            def evaluate(sess, valid_data, batch_size):
                """Accuracy over the whole validation set (0.0 if it is empty)."""
                example_num = 0
                correct_num = 0
                for valid_x, valid_y in data_loader.batch_iter(valid_data, batch_size):
                    batch_len, batch_correct = dev_step(sess, valid_x, valid_y)
                    example_num += batch_len
                    correct_num += batch_correct
                if example_num == 0:
                    return 0.0
                return float(correct_num) / example_num

            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            init_op = tf.global_variables_initializer()
            saver = tf.train.Saver(tf.global_variables())

            ################################################################
            # The Supervisor initialises variables on the chief (task 0) and
            # saves a checkpoint every 60 seconds; other workers wait for it.
            sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                    logdir=checkpoint_prefix,
                                    init_op=init_op,
                                    summary_op=None,
                                    saver=saver,
                                    global_step=global_step,
                                    save_model_secs=60)

            with sv.prepare_or_wait_for_session(server.target) as sess:
                # Generate batches
                batch_iter = data_loader.global_batch_iter(
                    train_data, FLAGS.batch_size, FLAGS.num_epochs)
                # Training loop. For each batch...
                # NOTE(review): global_step is advanced by every worker while
                # batch_num_per_epoch counts this worker's shard only, so
                # epoch_index here is not a per-shard epoch - confirm intent.
                total_steps = batch_num_per_epoch * FLAGS.num_epochs
                current_step = sess.run(global_step)
                while current_step < total_steps:
                    epoch_index = current_step // batch_num_per_epoch
                    if current_step % batch_num_per_epoch == 0:
                        print("Epoch %d" % epoch_index)

                    x_batch, y_batch = next(batch_iter)
                    loss, time_str = train_step(sess, x_batch, y_batch, epoch_index)

                    if current_step % FLAGS.evaluate_every == 0:
                        accuracy = evaluate(sess, valid_data, FLAGS.batch_size)
                        print("{}: step {}, loss {:g}, acc {:g}".format(time_str, current_step, loss, accuracy))

                    if current_step % FLAGS.checkpoint_every == 0:
                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        print("Saved model checkpoint to {}\n".format(path))

                    # Refresh the step *after* training so the loop condition
                    # sees the up-to-date value; testing the stale value ran
                    # one step too many.
                    current_step = sess.run(global_step)
            sv.stop()


# Script entry point: runs only when this file is executed directly, not when
# it is imported. Control is handed to tf.app.run(), which is expected to
# invoke main() defined above.
if __name__ == "__main__":
  tf.app.run()
4

0 回答 0