import os

import tensorflow as tf

from tensor2tensor.bin import t2t_trainer
from tensor2tensor.utils import decoding
from tensor2tensor.utils import trainer_lib

FLAGS = tf.flags.FLAGS


def create_hparams():
    return trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=os.path.expanduser(FLAGS.data_dir),
        problem_name=FLAGS.problem)


def create_decode_hparams():
    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.shards = FLAGS.decode_shards
    decode_hp.shard_id = FLAGS.worker_id
    decode_in_memory = FLAGS.decode_in_memory or decode_hp.decode_in_memory
    decode_hp.decode_in_memory = decode_in_memory
    decode_hp.decode_to_file = FLAGS.decode_to_file
    decode_hp.decode_reference = FLAGS.decode_reference
    return decode_hp

hp = create_hparams()
decode_hp = create_decode_hparams()
run_conf = t2t_trainer.create_run_config(hp)
estimator = trainer_lib.create_estimator(
    FLAGS.model,
    hp,
    run_conf,
    decode_hparams=decode_hp,
    use_tpu=FLAGS.use_tpu)
print(run_conf.session_config)

def input_fn():
    # Serving input: a batch of int32 token ids shaped [batch, length, 1, 1].
    inputs = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="inputs")
    input_tensor = {'inputs': inputs}
    return tf.estimator.export.ServingInputReceiver(input_tensor, input_tensor)

predictor = tf.contrib.predictor.from_estimator(estimator, input_fn)
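
Once the predictor builds, I would call it roughly like this (just a sketch; the dummy token-id array below is only an illustration of the shape the serving placeholder expects):

# Intended usage (sketch): feed int32 token ids shaped like the
# serving placeholder (1, length, 1, 1); the values here are dummies.
import numpy as np

dummy_ids = np.zeros((1, 5, 1, 1), dtype=np.int32)
outputs = predictor({"inputs": dummy_ids})
print(outputs.keys())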

However, I get the following output:

InvalidArgumentError: Cannot assign a device for operation transformer/body/parallel_0/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention/attention: Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.
Colocation Debug Info:
Colocation group had the following types and supported devices:
Root Member(assigned_device_name_index_=-1 requested_device_name_='/device:GPU:0' assigned_device_name_='' resource_device_name_='' supported_device_types_=[CPU] possible_devices_=[])
ImageSummary: CPU

Colocation members, user-requested devices, and framework assigned devices, if any:
  transformer/body/parallel_0/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention/attention (ImageSummary) /device:GPU:0

Op: ImageSummary
Node attrs: max_images=1, T=DT_FLOAT, bad_color=Tensor
Registered kernels:
  device='CPU'

When I print run_conf.session_config, I see allow_soft_placement: true. Many people say this setting resolves the InvalidArgumentError, but it does not seem to work for me.
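
To rule out the config being ignored, one thing I considered is forcing soft placement on the proto directly before building the predictor (just a sketch; I am assuming the estimator's sessions actually pick up this in-place change):

# Sketch: force soft placement (and log placements) on the ConfigProto
# before tf.contrib.predictor.from_estimator builds its session.
run_conf.session_config.allow_soft_placement = True
run_conf.session_config.log_device_placement = True
print(run_conf.session_config)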

