我正在尝试使用 python SDk 和 tensorflow 在 sagemaker 上进行测试分类。我可以修改此https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/tensorflow_abalone_age_predictor_using_keras/abalone.py并运行它,但是当我更改拱门以包含嵌入层时,我得到错误
“Fetch 参数不能解释为张量。(Tensor Tensor("first-layer/embeddings:0", shape=(*, ), dtype=float32_ref) 不是此图中的元素。”
当我将它作为独立模型运行时,它运行完美。这是独立模型的拱门
model = Sequential()
model.add(Embedding(len(word_index) + 1,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=MAX_SEQUENCE_LENGTH,
trainable=False))
model.add(Conv1D(64, kernel_size=10, padding='same', activation='relu'))
model.add(Conv1D(64, kernel_size=15, padding='same', activation='selu'))
model.add(Conv1D(128, kernel_size=15, padding='same', activation='relu'))
model.add(Conv1D(64, kernel_size=25, padding='same', activation='softmax'))
model.add(Conv1D(128, kernel_size=15, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dense(2, activation='softmax'))
这是我用于 sagemaker 的 model_fn:
embedding = tf.keras.layers.Embedding(len(word_index) + 1,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=MAX_SEQUENCE_LENGTH,
trainable=False, name='first-layer')(features[INPUT_TENSOR_NAME])
first = tf.keras.layers.Conv1D(64, kernel_size=10, padding='same', activation='relu')(embedding)
second = tf.keras.layers.Conv1D(64, kernel_size=15, padding='same', activation='relu')(first)
third = tf.keras.layers.Conv1D(128, kernel_size=15, padding='same', activation='relu')(second)
fourth = tf.keras.layers.Conv1D(64, kernel_size=25, padding='same', activation='softmax')(third)
fifth = tf.keras.layers.Conv1D(128, kernel_size=15, padding='same', activation='relu')(fourth)
sixth = tf.keras.layers.BatchNormalization()(fifth)
output = tf.keras.layers.Flatten()(sixth)
output_layer = tf.keras.layers.Dense(2, activation='softmax'))(output)
输入尺寸或值没有问题,如果我只用一个简单的密集层拱门替换这个拱门,代码就可以完美运行。
我已经在 TensorFlow 上尝试过解决方案:张量不是此图的元素,但出现新错误
输入图和层图不一样:Tensor("random_shuffle_queue_DequeueMany:1", shape=(128, 200), dtype=float32, device=/device:CPU:0) 不是来自传入的图。*