顶级问题
我想使用一个教师网络,并仅利用其部分特征,将它的性能/知识蒸馏到另一个更简单的模型中。
尝试的解决方案
我正在尝试开始使用 T2T 蒸馏代码。 https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/distillation.py
尝试解决的问题
我正在努力弄清楚如何把它用在我自己的教师模型和学生模型上。有没有可以先直接运行、帮助我了解其工作方式的示例?如何让它在 T2T 中已有的模型上工作?如何让它在用 Keras 定义的模型上工作?
我的理解是,下面这段代码是用来在 T2T 中注册模型超参数的。但是我该如何开始训练呢?我在 GitHub 上搜索了 distill_resnet_32_to_15_cifar20x5,唯一的结果是 T2T 代码仓库的若干重复 fork,其中没有任何关于如何使用它的示例。
@registry.register_hparams
def distill_resnet_32_to_15_cifar20x5():
    """Hyperparameters for distilling a ResNet-32 teacher into a ResNet-15.

    Starts from ``distill_base()`` and overrides the teacher/student
    selection, the per-phase learning rates, the learning-rate schedule and
    the distillation settings.

    Returns:
      The configured hparams object.
    """
    hparams = distill_base()

    # Batch-size scaling factor: batch_size=128*8 (on TPU, or 8 GPUs) = 1024,
    # taken relative to a reference of 256.
    batch_scale = 128. * 8. / 256.

    overrides = (
        ("teacher_model", "resnet"),
        ("teacher_hparams", "resnet_cifar_32"),
        ("student_model", "resnet"),
        ("student_hparams", "resnet_cifar_15"),
        ("optimizer_momentum_nesterov", True),
        # Phase-specific rates: base rate times the batch-size scaling.
        ("teacher_learning_rate", 0.25 * batch_scale),
        ("student_learning_rate", 0.2 * batch_scale),
        ("learning_rate_decay_scheme", "piecewise"),
        ("task_balance", 0.28),
        ("distill_temperature", 2.0),
        ("num_classes", 20),
    )
    for name, value in overrides:
        setattr(hparams, name, value)

    # Hyperparameters for the piecewise schedule; they are newly registered
    # here rather than overridden.
    schedule = (
        ("learning_rate_boundaries", [40000, 60000, 80000]),
        ("learning_rate_multiples", [0.1, 0.01, 0.001]),
    )
    for name, value in schedule:
        hparams.add_hparam(name, value)

    return hparams
完整代码:distillation.py
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Traditional Student-Teacher Distillation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model
import tensorflow as tf
@registry.register_model
class Distillation(t2t_model.T2TModel):
    """Distillation from a teacher to student network.

    First, a teacher is trained on a task; Second, a student is trained to
    perform the task while matching the teacher's softened outputs. For more
    details, see the paper below.

    In the hparams passed to this model include the desired
    {teacher/student}_model and {teacher/student}_hparams to be used. Also,
    specify the distillation temperature and task-distillation balance.

    Distilling the Knowledge in a Neural Network
    Hinton, Vinyals and Dean
    https://arxiv.org/abs/1503.02531
    """

    def __init__(self,
                 hparams,
                 mode=tf.estimator.ModeKeys.TRAIN,
                 problem_hparams=None,
                 data_parallelism=None,
                 decode_hparams=None):
        """Builds the teacher and student sub-models.

        Args:
          hparams: distillation hparams. ``distill_phase`` must be "train"
            or "distill"; ``teacher_model``/``student_model`` and
            ``teacher_hparams``/``student_hparams`` are registry names.
          mode: estimator mode, forwarded to both sub-models and the base
            class.
          problem_hparams: forwarded unchanged.
          data_parallelism: forwarded unchanged.
          decode_hparams: forwarded unchanged.
        """
        assert hparams.distill_phase in ["train", "distill"], (
            "distill_phase must be 'train' or 'distill', got %r" %
            (hparams.distill_phase,))

        # A phase-specific learning rate, when set, replaces the generic one.
        if hparams.distill_phase == "train" and hparams.teacher_learning_rate:
            hparams.learning_rate = hparams.teacher_learning_rate
        elif (hparams.distill_phase == "distill" and
              hparams.student_learning_rate):
            hparams.learning_rate = hparams.student_learning_rate

        self.teacher_hparams = registry.hparams(hparams.teacher_hparams)
        self.teacher_model = registry.model(hparams.teacher_model)(
            self.teacher_hparams, mode, problem_hparams, data_parallelism,
            decode_hparams)

        self.student_hparams = registry.hparams(hparams.student_hparams)
        self.student_model = registry.model(hparams.student_model)(
            self.student_hparams, mode, problem_hparams, data_parallelism,
            decode_hparams)

        super(Distillation, self).__init__(hparams, mode, problem_hparams,
                                           data_parallelism, decode_hparams)

    def body(self, features):
        """Runs the teacher (and, when distilling, the student) network.

        In the "train" phase only the teacher is built and its
        cross-entropy against the labels is the loss. In the "distill" phase
        the teacher is restored from a checkpoint and frozen, the student is
        built, and the loss is a ``task_balance``-weighted mix of the
        student's label cross-entropy and its cross-entropy against the
        teacher's temperature-softened outputs.

        Args:
          features: dict of input tensors; "targets_raw" is read here and
            the whole dict is passed to the sub-models' ``body``.

        Returns:
          A tuple ``(outputs, losses)``: the active network's logits
          reshaped to ``[-1, 1, 1, 1, num_classes]`` and a dict
          ``{"training": loss}``.
        """
        hp = self.hparams
        is_distill = hp.distill_phase == "distill"

        targets = features["targets_raw"]
        # Drop the three singleton axes so targets is one class id per example.
        targets = tf.squeeze(targets, [1, 2, 3])
        one_hot_targets = tf.one_hot(targets, hp.num_classes, dtype=tf.float32)

        # Teacher Network
        with tf.variable_scope("teacher"):
            teacher_outputs = self.teacher_model.body(features)
            tf.logging.info(
                "teacher output shape: %s" % teacher_outputs.get_shape())
            # Average over axes 1 and 2 before the classification layer.
            teacher_outputs = tf.reduce_mean(teacher_outputs, axis=[1, 2])
            teacher_logits = tf.layers.dense(teacher_outputs, hp.num_classes)

            teacher_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=one_hot_targets, logits=teacher_logits)
            outputs = teacher_logits

        if is_distill:
            # Load teacher weights
            # NOTE(review): teacher_dir is not defined by distill_base();
            # presumably the caller adds it to hparams - confirm.
            tf.train.init_from_checkpoint(
                hp.teacher_dir, {"teacher/": "teacher/"})
            # Do not train the teacher: emptying the collection in place
            # removes every variable created so far; variables created after
            # this point (the student's) are added to it again.
            trainable_vars = tf.get_collection_ref(
                tf.GraphKeys.TRAINABLE_VARIABLES)
            del trainable_vars[:]

        # Student Network
        if is_distill:
            with tf.variable_scope("student"):
                student_outputs = self.student_model.body(features)
                tf.logging.info(
                    "student output shape: %s" % student_outputs.get_shape())
                student_outputs = tf.reduce_mean(student_outputs, axis=[1, 2])
                student_logits = tf.layers.dense(
                    student_outputs, hp.num_classes)

                student_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=one_hot_targets, logits=student_logits)

                # Soften BOTH distributions with the same temperature; the
                # student must be compared to the teacher at the temperature
                # the soft targets were produced with.
                teacher_targets = tf.nn.softmax(
                    teacher_logits / hp.distill_temperature)
                student_distill_xent = (
                    tf.nn.softmax_cross_entropy_with_logits_v2(
                        labels=tf.stop_gradient(teacher_targets),
                        logits=student_logits / hp.distill_temperature))
                # Soft-target gradients shrink as 1/T^2, so rescale to keep
                # this term comparable to the hard-label term.
                student_distill_xent *= hp.distill_temperature ** 2

                outputs = student_logits

                # Summaries: the cross-entropy has one value per example, a
                # scalar summary needs a single value.
                tf.summary.scalar(
                    "distill_xent", tf.reduce_mean(student_distill_xent))

        if not is_distill:
            phase_loss = teacher_task_xent
        else:
            phase_loss = hp.task_balance * student_task_xent
            phase_loss += (1 - hp.task_balance) * student_distill_xent

        losses = {"training": phase_loss}
        outputs = tf.reshape(outputs, [-1, 1, 1, 1, outputs.shape[1]])

        return outputs, losses

    def top(self, body_output, features):
        """Returns ``body_output`` unchanged; ``body`` already made logits."""
        return body_output
def distill_base():
    """Base hyperparameter set shared by the distillation configurations.

    Starts from ``common_hparams.basic_params1()``, registers the
    distillation-specific hyperparameters with their defaults, then overrides
    the training settings.

    Returns:
      The populated hparams object.
    """
    hparams = common_hparams.basic_params1()

    distill_defaults = (
        # Teacher/student model and hparams-set names.
        ("teacher_model", ""),
        ("teacher_hparams", ""),
        ("student_model", ""),
        ("student_hparams", ""),
        # Distillation parameters.
        # WARNING: distill_phase hparam will be overwritten in
        # /bin/t2t_distill.py
        ("distill_phase", None),
        ("task_balance", 1.0),
        ("distill_temperature", 1.0),
        ("num_classes", 10),
        # Optional phase-specific learning rates.
        ("teacher_learning_rate", None),
        ("student_learning_rate", None),
    )
    for name, default in distill_defaults:
        hparams.add_hparam(name, default)

    training_settings = (
        # Training parameters (borrowed from ResNet).
        ("batch_size", 128),
        ("optimizer", "Momentum"),
        ("optimizer_momentum_momentum", 0.9),
        ("optimizer_momentum_nesterov", True),
        ("weight_decay", 1e-4),
        ("clip_grad_norm", 0.0),
        # (base_lr=0.1) * (batch_size=128*8 (on TPU, or 8 GPUs)=1024) / (256.)
        ("learning_rate", 0.4),
        ("learning_rate_decay_scheme", "cosine"),
        # For image_imagenet224, 120k training steps, which effectively makes
        # this a cosine decay (i.e. no cycles).
        ("learning_rate_cosine_cycle_steps", 120000),
        ("initializer", "normal_unit_scaling"),
        ("initializer_gain", 2.),
    )
    for name, value in training_settings:
        setattr(hparams, name, value)

    return hparams
@registry.register_hparams
def distill_resnet_32_to_15_cifar20x5():
    """Hyperparameters for distilling a ResNet-32 teacher into a ResNet-15.

    Starts from ``distill_base()`` and overrides the teacher/student
    selection, the per-phase learning rates, the learning-rate schedule and
    the distillation settings.

    Returns:
      The configured hparams object.
    """
    hparams = distill_base()

    # Batch-size scaling factor: batch_size=128*8 (on TPU, or 8 GPUs) = 1024,
    # taken relative to a reference of 256.
    batch_scale = 128. * 8. / 256.

    overrides = (
        ("teacher_model", "resnet"),
        ("teacher_hparams", "resnet_cifar_32"),
        ("student_model", "resnet"),
        ("student_hparams", "resnet_cifar_15"),
        ("optimizer_momentum_nesterov", True),
        # Phase-specific rates: base rate times the batch-size scaling.
        ("teacher_learning_rate", 0.25 * batch_scale),
        ("student_learning_rate", 0.2 * batch_scale),
        ("learning_rate_decay_scheme", "piecewise"),
        ("task_balance", 0.28),
        ("distill_temperature", 2.0),
        ("num_classes", 20),
    )
    for name, value in overrides:
        setattr(hparams, name, value)

    # Hyperparameters for the piecewise schedule; they are newly registered
    # here rather than overridden.
    schedule = (
        ("learning_rate_boundaries", [40000, 60000, 80000]),
        ("learning_rate_multiples", [0.1, 0.01, 0.001]),
    )
    for name, value in schedule:
        hparams.add_hparam(name, value)

    return hparams