我正在尝试为一个问题引入一个 mod/mixin 。特别是,我在这里重点关注SpeechRecognitionProblem
. 我打算修改此问题,因此我寻求执行以下操作:
class SpeechRecognitionProblemMod(speech_recognition.SpeechRecognitionProblem):
def hparams(self, defaults, model_hparams):
SpeechRecognitionProblem.hparams(self, defaults, model_hparams)
vocab_size = self.feature_encoders(model_hparams.data_dir)['targets'].vocab_size
p = defaults
p.vocab_size['targets'] = vocab_size
def feature_encoders(self, data_dir):
# ...
所以这个做的不多。它hparams()
从基类调用函数,然后更改一些值。
现在,已经有一些现成的问题,例如 Libri Speech:
@registry.register_problem()
class Librispeech(speech_recognition.SpeechRecognitionProblem):
# ..
但是,为了应用我的修改,我这样做:
@registry.register_problem()
class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
# ..
如果我没记错的话,这应该覆盖其中的所有内容(具有相同的签名)Librispeech
,而是调用SpeechRecognitionProblemMod
.
由于我能够使用此代码训练模型,因此我假设它到目前为止按预期工作。
现在我的问题来了:
训练后我想序列化模型。这通常有效。但是,它不适用于我的 mod,我实际上知道为什么:
在某个时刻hparams()
被调用。调试到这一点将向我显示以下内容:
self # {LibrispeechMod}
self.hparams # <bound method SpeechRecognitionProblem.hparams of ..>
self.feature_encoders # <bound method SpeechRecognitionProblemMod.feature_encoders of ..>
self.hparams
应该是<bound method SpeechRecognitionProblemMod.hparams of ..>
!似乎由于某种原因直接hparams()
被SpeechRecognitionProblem
调用而不是SpeechRecognitionProblemMod
. 但请注意,它是正确的类型feature_encoders()
!
问题是我知道这在训练期间有效。我可以看到超参数(hparams)被相应地应用,因为模型的图形节点名称通过我的修改而改变。
我需要指出一个专业。tensor2tensor
允许动态加载 a t2t_usr_dir
,这是由 . 加载的附加 python 模块import_usr_dir
。我也在我的序列化脚本中使用了该函数:
if usr_dir:
logging.info('Loading user dir %s' % usr_dir)
import_usr_dir(usr_dir)
这可能是我目前能看到的唯一罪魁祸首,尽管我无法说出为什么这可能会导致问题。
如果有人看到我没有看到的东西,我很乐意在这里得到提示我做错了什么。
那么你得到的错误是什么?
为了完整起见,这是hparams()
调用错误方法的结果:
NotFoundError (see above for traceback): Restoring from checkpoint failed.
Key transformer/symbol_modality_256_256/softmax/weights_0 not found in checkpoint
symbol_modality_256_256
是错的。它应该是symbol_modality_<vocab-size>_256
在哪里<vocab-size>
设置的词汇量SpeechRecognitionProblemMod.hparams
。