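I tried English-to-German translation in the Colab notebook "Welcome to the Tensor2Tensor Colab" and it works well. But I must be missing something in the code to make it work for German to English.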

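Following the README at https://github.com/tensorflow/tensor2tensor, I appended "_rev" to the problem name to reverse the translation direction. The two changes relative to the original notebook are marked with '# <-------------':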

# Fetch the problem
ende_problem = problems.problem("translate_ende_wmt32k_rev") # <------------- 

# Copy the vocab file locally so we can encode inputs and decode model outputs
# All vocabs are stored on GCS
vocab_name = "vocab.translate_ende_wmt32k.32768.subwords"
vocab_file = os.path.join(gs_data_dir, vocab_name)
!gsutil cp {vocab_file} {data_dir}

# Get the encoders from the problem
encoders = ende_problem.feature_encoders(data_dir)

# Setup helper functions for encoding and decoding
def encode(input_str, output_str=None):
  """Input str to features dict, ready for inference"""
  inputs = encoders["inputs"].encode(input_str) + [1]  # add EOS id
  batch_inputs = tf.reshape(inputs, [1, -1, 1])  # Make it 3D.
  return {"inputs": batch_inputs}

def decode(integers):
  """List of ints to str"""
  integers = list(np.squeeze(integers))
  if 1 in integers:
    integers = integers[:integers.index(1)]
  return encoders["inputs"].decode(np.squeeze(integers))

# Create hparams and the model
model_name = "transformer"
hparams_set = "transformer_base"

hparams = trainer_lib.create_hparams(hparams_set, data_dir=data_dir, problem_name="translate_ende_wmt32k_rev") # <-------------

# NOTE: Only create the model once when restoring from a checkpoint; it's a
# Layer and so subsequent instantiations will have different variable scopes
# that will not match the checkpoint.
translate_model = registry.model(model_name)(hparams, Modes.EVAL)

# Copy the pretrained checkpoint locally
ckpt_name = "transformer_ende_test"
gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)
!gsutil -q cp -R {gs_ckpt} {checkpoint_dir}
ckpt_path = tf.train.latest_checkpoint(os.path.join(checkpoint_dir, ckpt_name))
ckpt_path

# Restore and translate!
def translate(inputs):
  encoded_inputs = encode(inputs)
  with tfe.restore_variables_on_create(ckpt_path):
    model_output = translate_model.infer(encoded_inputs)["outputs"]
  return decode(model_output)

inputs = "Sie ist zurückgetreten."
outputs = translate(inputs)

print("Inputs: %s" % inputs)
print("Outputs: %s" % outputs)

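The output is:

  • Inputs: Sie ist zurückgetreten.
  • Outputs: Sie sind zurückgetreten.

    The translation still seems to go from English to German rather than the other way round: the German input comes back as German.

    What am I missing?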

    1 Answer

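    The model you load from the checkpoint ckpt_name = "transformer_ende_test" (downloaded from gs_ckpt_dir) was trained only for English to German. Changing the problem name to "translate_ende_wmt32k_rev" does not change the restored weights, so the model still translates in the direction it was trained on. You need to find a checkpoint of a model trained in the opposite direction, or train one yourself. I am not aware of any publicly available German-to-English T2T model checkpoints.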
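If you go the route of training one yourself, a minimal sketch of the commands involved might look like the following. It only builds and prints the `t2t-datagen` and `t2t-trainer` command lines and does not run them. The directory paths and the helper name `build_t2t_commands` are hypothetical, and the sketch assumes that data is generated for the base problem while the "_rev" suffix is applied at training time, as the T2T README suggests.

```python
def build_t2t_commands(data_dir, tmp_dir, output_dir,
                       base_problem="translate_ende_wmt32k",
                       model="transformer",
                       hparams_set="transformer_base"):
    """Return (datagen_cmd, trainer_cmd) as argument lists.

    Hypothetical helper for illustration; it only assembles command lines.
    """
    # Assumption: data is generated once for the base (en->de) problem and
    # the reversed problem reads the same files with inputs/targets swapped.
    datagen_cmd = [
        "t2t-datagen",
        "--data_dir=%s" % data_dir,
        "--tmp_dir=%s" % tmp_dir,
        "--problem=%s" % base_problem,
    ]
    # The "_rev" suffix asks T2T to train in the reverse direction (de->en).
    trainer_cmd = [
        "t2t-trainer",
        "--data_dir=%s" % data_dir,
        "--problem=%s_rev" % base_problem,
        "--model=%s" % model,
        "--hparams_set=%s" % hparams_set,
        "--output_dir=%s" % output_dir,
    ]
    return datagen_cmd, trainer_cmd


# Placeholder paths; replace them with your own directories.
datagen_cmd, trainer_cmd = build_t2t_commands(
    data_dir="/tmp/t2t_data",
    tmp_dir="/tmp/t2t_datagen",
    output_dir="/tmp/t2t_train/ende_rev",
)
print(" ".join(datagen_cmd))
print(" ".join(trainer_cmd))
```

Once training has produced checkpoints, the notebook code from the question should work with ckpt_path set to tf.train.latest_checkpoint of that output directory instead of the downloaded "transformer_ende_test" checkpoint.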

    Answered 2019-09-09T18:51:45.293