def create_hparams():
    """Build the model hyperparameters from command-line flags.

    Delegates to ``trainer_lib.create_hparams`` with the hparams set named by
    ``FLAGS.hparams_set``, the override string ``FLAGS.hparams``, the data
    directory from ``FLAGS.data_dir`` (with ``~`` expanded) and the problem
    name ``FLAGS.problem``.

    Returns:
        Whatever ``trainer_lib.create_hparams`` returns for those arguments.
    """
    expanded_data_dir = os.path.expanduser(FLAGS.data_dir)
    return trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=expanded_data_dir,
        problem_name=FLAGS.problem,
    )
def create_decode_hparams():
    """Build decoding hyperparameters from command-line flags.

    Parses ``FLAGS.decode_hparams`` with ``decoding.decode_hparams`` and then
    overwrites several fields from the corresponding flags:
    ``shards``/``shard_id`` from ``decode_shards``/``worker_id``, and
    ``decode_to_file``/``decode_reference`` from the flags of the same name.
    ``decode_in_memory`` becomes ``FLAGS.decode_in_memory`` if that is truthy,
    otherwise it keeps the value parsed from ``FLAGS.decode_hparams``.

    Returns:
        The updated decode hparams object.
    """
    hparams = decoding.decode_hparams(FLAGS.decode_hparams)
    # Sharding: total shard count and this worker's shard index.
    hparams.shards = FLAGS.decode_shards
    hparams.shard_id = FLAGS.worker_id
    # Either the flag or the parsed hparams may request in-memory decoding.
    hparams.decode_in_memory = (
        FLAGS.decode_in_memory or hparams.decode_in_memory
    )
    hparams.decode_to_file = FLAGS.decode_to_file
    hparams.decode_reference = FLAGS.decode_reference
    return hparams
# Model and decoding hyperparameters, both derived from command-line flags.
hp = create_hparams()
decode_hp = create_decode_hparams()

# Estimator run configuration built from the model hparams.
run_conf = t2t_trainer.create_run_config(hp)
estimator = trainer_lib.create_estimator(
    FLAGS.model, hp, run_conf,
    decode_hparams=decode_hp,
    use_tpu=FLAGS.use_tpu,
)

# Debug output: show the session config (e.g. allow_soft_placement) that the
# run configuration carries.
print(run_conf.session_config)
def input_fn():
    """Serving input function exposing a single ``inputs`` placeholder.

    Creates an int32 placeholder named ``"inputs"`` with static shape
    ``(1, None, 1, 1)``. Presumably this is (batch=1, sequence length, 1, 1)
    token ids, which should be confirmed against the problem's encoders. The
    same dict is used both as the model features and as the receiver tensors.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver``.
    """
    feeds = dict(
        inputs=tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="inputs"),
    )
    return tf.estimator.export.ServingInputReceiver(
        features=feeds,
        receiver_tensors=feeds,
    )
# `from_estimator` opens its own session and does not read the estimator's
# RunConfig. Without an explicit `config=`, that session runs with default
# options, so allow_soft_placement is off even though
# run_conf.session_config sets it.
# The transformer graph pins ops such as the attention ImageSummary to
# /device:GPU:0, and ImageSummary only has a CPU kernel. Soft placement lets
# TensorFlow put those ops on the CPU instead of raising InvalidArgumentError.
_predictor_config = tf.ConfigProto()
if run_conf.session_config is not None:
    # Keep the rest of the estimator's session settings (GPU options, etc.).
    _predictor_config.CopyFrom(run_conf.session_config)
_predictor_config.allow_soft_placement = True
predictor = tf.contrib.predictor.from_estimator(
    estimator, input_fn, config=_predictor_config)
运行后得到如下输出:
InvalidArgumentError: Cannot assign a device for operation transformer/body/parallel_0/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention/attention: Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.
Colocation Debug Info:
Colocation group had the following types and supported devices:
Root Member(assigned_device_name_index_=-1 requested_device_name_='/device:GPU:0' assigned_device_name_='' resource_device_name_='' supported_device_types_=[CPU] possible_devices_=[]
ImageSummary: CPU

Colocation members, user-requested devices, and framework assigned devices, if any:
  transformer/body/parallel_0/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention/attention (ImageSummary) /device:GPU:0

Op: ImageSummary
Node attrs: max_images=1, T=DT_FLOAT, bad_color=Tensor
Registered kernels:
  device='CPU'
打印 run_conf.session_config 时,输出中包含 allow_soft_placement: true。很多人说开启这个选项可以解决此类 InvalidArgumentError,但在我这里似乎不起作用。