0

我对 Trax 如何输出预测结果感到困惑。如下所示,每条输入的预测结果只包含 11 个数值(对应 11 个输出单元),但这些数值是按什么顺序排列的,才能与我的标签一一对应?

def classifier(vocab_size=len(Vocab), embedding_dim=256, output_dim=11, mode='train'):
    """Build a bag-of-embeddings text classifier as a ``tl.Serial`` model.

    Pipeline: token ids -> Embedding -> mean over axis 1 -> Dense(output_dim)
    -> LogSoftmax.

    Output ordering: each example yields one row of ``output_dim`` scores, and
    column ``i`` is the score for integer class label ``i`` -- the same integer
    the category cross-entropy loss one-hot encodes as the target during
    training (assumes the training generator yields class ids
    ``0 .. output_dim - 1``; confirm against the data pipeline). There is no
    other reordering: ``argmax`` over the last axis is the predicted label.

    Args:
        vocab_size: Number of rows in the embedding table. The default
            ``len(Vocab)`` is evaluated once, when this ``def`` executes, so
            ``Vocab`` must already exist at that point.
        embedding_dim: Width of each word embedding.
        output_dim: Number of output units / classes (11 here).
        mode: Accepted for API compatibility; not used by this function.

    Returns:
        A ``trax.layers.combinators.Serial`` model producing log-probabilities
        with one column per class; apply ``np.exp`` to obtain probabilities.
    """
    # Map each token id to a dense vector.
    embed_layer = tl.Embedding(
        vocab_size=vocab_size,    # size of the vocabulary
        d_feature=embedding_dim,  # embedding dimension
    )

    # Average the word embeddings along axis 1 into one vector per example.
    mean_layer = tl.Mean(axis=1)

    # One output unit per class; unit i scores class label i.
    dense_output_layer = tl.Dense(n_units=output_dim)

    # Log-softmax rather than Softmax: Trax's category cross-entropy losses
    # take pre-softmax activations (or log-probabilities) and normalize
    # internally. Log-softmax is idempotent, so it is safe under either
    # convention, whereas plain Softmax would be normalized twice and squash
    # the gradients during training. argmax is unchanged either way.
    log_softmax_layer = tl.LogSoftmax()

    model = tl.Serial(
        embed_layer,
        mean_layer,
        dense_output_layer,
        log_softmax_layer,
    )
    return model

from trax.supervised import training

# Examples per batch for both the training and the evaluation data streams.
batch_size = 16
# Seed the RNG before the data generators below are created.
# NOTE(review): `rnd` is not imported in this section -- presumably Python's
# `random` module imported elsewhere; confirm it is the RNG the generators use.
rnd.seed(271)

# Training configuration: shuffled batches from train_generator, weighted
# category cross-entropy loss, Adam with learning rate 0.01, and a checkpoint
# every 20 steps.
# NOTE(review): Trax's category cross-entropy expects pre-softmax activations
# (or log-probabilities) as model output -- confirm classifier's final layer
# matches.
train_task = training.TrainTask(
    labeled_data=train_generator(batch_size=batch_size, shuffle=True),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=20,
)

# Evaluation configuration: reports weighted cross-entropy and weighted
# accuracy on shuffled batches from val_generator.
eval_task = training.EvalTask(
    labeled_data=val_generator(batch_size=batch_size, shuffle=True),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
)

# NOTE(review): neither train_task nor eval_task is consumed in the code shown
# here (no training.Loop), so `model` below is a freshly initialized classifier
# unless it is trained elsewhere -- confirm before interpreting its predictions.
model = classifier()


# Run the classifier on every sentence in `doc` and report its prediction.
# Each model output row holds one score per class, and position i is the score
# for integer label i, so argmax over the last axis is the predicted label --
# no reordering is needed to line the scores up with the labels.
y_preds = []  # raw model outputs, one row-array per successfully scored sentence
texts = []    # sentences aligned index-for-index with y_preds
for sent in doc:
    inputs = np.array(sent_to_tensor(sent, vocab_dict=Vocab))
    if inputs.size == 0:
        # Nothing to embed: skip explicitly instead of feeding an empty
        # sequence to the model and hiding the resulting error.
        print(f'skipping input_str with no tokens: {sent!r}')
        continue
    inputs = inputs[None, :]  # prepend a batch axis of size 1
    try:
        predictions = model(inputs)
    except Exception as err:
        # Best-effort over the whole document: report the failure and move on
        # rather than silently discarding it (and without trapping Ctrl-C).
        print(f'skipping input_str {sent!r}: {type(err).__name__}: {err}')
        continue
    y_preds.append(predictions)
    texts.append(sent)
    predicted_label = int(np.argmax(predictions, axis=-1)[0])
    print(f'example input_str: {sent}')
    print(f'Predicted label index: {predicted_label}')
    print(f'Model output scores (index i = label i): {predictions}')
# Each printed score row has one entry per output unit of the classifier.

上述代码的输出是:

example input_str: Hello, is this apple product support?

Model returned sentiment probabilities: [[2.2228214e-10 9.9999809e-01 2.1072233e-07 3.0890752e-07 2.4897969e-07 2.2614840e-07 3.2589759e-07 2.4826460e-07 2.3848202e-07 2.1718901e-07 2.8736207e-07]]

Trax 输出的这些预测值是按什么顺序排列的?如何将它们与我的标签一一对应?

4

0 回答 0