1

总的来说,我对 Keras 和 Tensorflow 很陌生,所以也许这是一个愚蠢的问题......

我想要实现的是:我有一组单词,比如说:猫、狗、牛……这些单词应该根据给定的字母进行编码,在字符的位置上有一个 1向量,否则为 0。对于猫,例如 1,0,1,0,0,0,0,0,....,1,0,0,...0。

为此,我使用 Keras Tokenizer:

tk = Tokenizer(char_level=True, oov_token='UNK')

alphabet="abcdefghijklmnopqrstuvwxyzöäü0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"
char_dict = {}
for i, char in enumerate(alphabet):
    char_dict[char] = i + 1

# Use char_dict to replace the tk.word_index
tk.word_index = char_dict 
# Add 'UNK' to the vocabulary 
tk.word_index[tk.oov_token] = max(char_dict.values()) + 1

x_train = tk.texts_to_matrix(x_train)

这些向量被传递到 Keras 模型中进行预测。现在,我希望转换发生在 Keras 模型中。因此,用户应该向模型提供“cat”,而不是像上面那样的数字向量。并且模型也应该返回“cat”。我怎样才能做到这一点?我看到 Keras 中有一个 Lambda 层,这是正确的做法吗?提前致谢。

编辑澄清:此刻的模型看起来像这样:

model = Sequential()
model.add(Dense(128, input_shape=(len(alphabet)+1,)))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

但是我想要实现的是有一个输入层,它将字符串作为输入并将字符串转换为实际第一层可以读取的格式。像这样:

model = Sequential()
**model.add(transformation_layer)**
model.add(Dense(128, input_shape=(len(alphabet)+1,)))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

编辑 2 这是我尝试过的,但在运行“model.fit”函数时出现以下错误:

tf.Tensortensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Graph 执行中不允许迭代。使用 Eager 执行或使用 @tf.function 装饰此函数

def transform_layer(x):
  return tk.texts_to_matrix(x)



print('Building model...')
transform_layer = Lambda(transform_layer)
model = Sequential()
model.add(transform_layer)
model.add(Dense(128, input_shape=(len(alphabet)+1,)))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

history = model.fit(np.array(['test','test2']), np.array(['blub','blub2']),
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_split=0.1)
4

0 回答 0