我正在处理一个类别不平衡的文本数据集。训练模型时，我使用 imbalanced-learn 为 TensorFlow 提供的平衡批次生成器 `balanced_batch_generator` 来生成类别平衡的批次，代码如下：
# Build an (infinite) generator of class-balanced batches plus the number of
# steps that make up one epoch.
# NOTE: in imblearn.tensorflow.balanced_batch_generator the parameter right
# after `y` is `sample_weight`, not `batch_size`. Passing BATCH positionally
# therefore treats the int as per-sample weights. Indexing that int while
# building each batch is the likely source of
# "TypeError: 'int' object is not subscriptable". Always pass batch_size by
# keyword (newer imblearn versions make it keyword-only anyway).
batch_generator, steps_per_epoch = balanced_batch_generator(
    training_x,
    training_y,
    batch_size=BATCH,
    random_state=42,
)
def build_model():
    """Build, compile and return a small binary text classifier.

    Architecture: Embedding -> GlobalAveragePooling1D -> two hidden
    Dense(16, relu) layers (each L2-regularized, paired with Dropout and
    BatchNormalization) -> Dense(1, sigmoid).

    Reads the module-level names ``vocab_size``, ``embedding_dim``,
    ``max_seq_len`` and ``metrics``. They must be defined before this is
    called.

    Returns:
        The compiled ``tf.keras.Sequential`` model. Its summary is printed as
        a side effect.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(
            vocab_size,
            embedding_dim,
            input_length=max_seq_len,
            embeddings_initializer='glorot_uniform',
        ),
        # Average the token embeddings over the sequence axis so each sample
        # becomes one fixed-size vector (cheaper than a recurrent layer).
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(
            16,
            activation='relu',
            kernel_initializer='glorot_uniform',
            kernel_regularizer=tf.keras.regularizers.l2(0.1),
        ),
        # NOTE(review): this block applies Dropout before BatchNormalization,
        # while the next block uses the opposite order — confirm which
        # ordering is intended.
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(
            16,
            activation='relu',
            kernel_initializer='glorot_uniform',
            kernel_regularizer=tf.keras.regularizers.l2(0.1),
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.5),
        # Single sigmoid unit: probability of the positive class, which is
        # what binary_crossentropy below expects.
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(
        loss='binary_crossentropy',
        # `learning_rate` is the supported keyword; the old `lr` alias is
        # deprecated and rejected by newer Keras optimizers.
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=metrics,
    )
    model.summary()
    return model
# Train on the balanced generator. `Model.fit` accepts Python generators
# directly; `fit_generator` is deprecated (TF 2.1+) and removed from recent
# Keras releases. `steps_per_epoch` tells Keras how many batches to draw from
# the generator per epoch. The result is a Keras History object.
callback_history = build_model().fit(
    batch_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=nu_epochs,
    verbose=1,
)
但运行时出现以下错误：
TypeError: 'int' object is not subscriptable