我对 Trax 输出预测结果的方式感到困惑。可以看到，模型的输出包含 11 个预测值，但这些值是按什么顺序排列的，又是如何与我的标签一一对应的？
def classifier(vocab_size=len(Vocab), embedding_dim=256, output_dim=11, mode='train'):
    """Build a mean-of-embeddings classifier as a ``tl.Serial`` model.

    The model embeds its input, averages the embeddings over axis 1,
    projects the average onto ``output_dim`` units and applies a softmax.

    Args:
        vocab_size: Vocabulary size passed to the embedding layer. The
            default ``len(Vocab)`` is evaluated once, at definition time.
        embedding_dim: Feature size of the embedding layer.
        output_dim: Number of units in the dense output layer (11 here).
        mode: Accepted for interface compatibility; unused by this body.

    Returns:
        The ``tl.Serial`` model combining the four layers.
    """
    # Embedding table: vocab_size rows, embedding_dim features each.
    embed_layer = tl.Embedding(
        vocab_size=vocab_size,
        d_feature=embedding_dim,
    )
    # Average over axis 1 to get a single "average" word embedding.
    mean_layer = tl.Mean(axis=1)
    # Dense layer with one unit per output.
    dense_output_layer = tl.Dense(n_units=output_dim)
    # Plain softmax (not log-softmax).
    # NOTE(review): the training task in this file uses
    # tl.WeightedCategoryCrossEntropy as its loss; confirm whether that loss
    # expects raw scores rather than softmax output.
    softmax_layer = tl.Softmax()

    # Chain the layers into one model of type trax.layers.combinators.Serial.
    model = tl.Serial(
        embed_layer,
        mean_layer,
        dense_output_layer,
        softmax_layer,
    )
    return model
from trax.supervised import training
# Batch size shared by the training and the validation data streams.
batch_size = 16
# Seed `rnd` before the data streams are created.
rnd.seed(271)

# Training task: shuffled training batches, weighted category cross-entropy
# loss, Adam optimizer constructed with 0.01.
train_stream = train_generator(batch_size=batch_size, shuffle=True)
train_task = training.TrainTask(
    labeled_data=train_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=20,
)

# Evaluation task: shuffled validation batches, reporting both the
# cross-entropy and the accuracy metric.
eval_stream = val_generator(batch_size=batch_size, shuffle=True)
eval_metrics = [
    tl.WeightedCategoryCrossEntropy(),
    tl.WeightedCategoryAccuracy(),
]
eval_task = training.EvalTask(
    labeled_data=eval_stream,
    metrics=eval_metrics,
)

# Instantiate the classifier with its default arguments.
model = classifier()
# Run the model on every sentence in `doc`, collecting the raw model outputs
# and the sentences they belong to in two parallel lists.
y_preds = []
texts = []
for sent in doc:
    inputs = np.array(sent_to_tensor(sent, vocab_dict=Vocab))
    # Add a leading axis of size 1 (presumably the batch axis).
    inputs = inputs[None, :]
    try:
        predictions = model(inputs)
    except Exception as err:
        # Best effort: keep going, but report the skipped sentence instead of
        # silently discarding it. KeyboardInterrupt/SystemExit still propagate.
        print(f'Skipping {sent!r}: {type(err).__name__}: {err}')
        continue
    y_preds.append(predictions)
    texts.append(sent)
    print(f'example input_str: {sent}')
    # NOTE(review): column i of `predictions` presumably scores integer class
    # id i as encoded in the training labels; predictions.argmax(axis=1) would
    # then give the predicted class id. Confirm against the label encoding.
    print(f'Model returned sentiment probabilities: {predictions}')
# We can observe that it now displays 11 output predictions
上述代码的输出是:
example input_str: Hello, is this apple product support?
Model returned sentiment probabilities: [[2.2228214e-10 9.9999809e-01 2.1072233e-07 3.0890752e-07 2.4897969e-07 2.2614840e-07 3.2589759e-07 2.4826460e-07 2.3848202e-07 2.1718901e-07 2.8736207e-07]]
Trax 输出的这些预测值是按什么顺序排列的？怎样才能把它们与我的标签对应起来？