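I wrote a program in Keras that distinguishes fake text from real text (5,000 training examples and 10,000 test examples), using Transformers and the "distilbert-base-uncased" model. I have now decided to tune the hyperparameters with grid search, but I get the following error: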

    TuneError                                 Traceback (most recent call last)
    <ipython-input-15-c4a44a2180d8> in <module>()
        156     tune_iris,
        157     verbose=1,
    --> 158     config=hyperparameter_space,
        159    )
        160 
    
    /usr/local/lib/python3.6/dist-packages/ray/tune/tune.py in run(run_or_experiment, name, stop, config, resources_per_trial, num_samples, local_dir, upload_dir, trial_name_creator, loggers, sync_to_cloud, sync_to_driver, checkpoint_freq, checkpoint_at_end, sync_on_checkpoint, keep_checkpoints_num, checkpoint_score_attr, global_checkpoint_period, export_formats, max_failures, fail_fast, restore, search_alg, scheduler, with_server, server_port, verbose, progress_reporter, resume, queue_trials, reuse_actors, trial_executor, raise_on_failed_trial, return_trials, ray_auto_init)
        354     if incomplete_trials:
        355         if raise_on_failed_trial:
    --> 356             raise TuneError("Trials did not complete", incomplete_trials)
        357         else:
        358             logger.error("Trials did not complete: %s", incomplete_trials)
    
    TuneError: ('Trials did not complete', [tune_iris_83131_00000, tune_iris_83131_00001, tune_iris_83131_00002, tune_iris_83131_00003, tune_iris_83131_00004, tune_iris_83131_00005, tune_iris_83131_00006, tune_iris_83131_00007, tune_iris_83131_00008, tune_iris_83131_00009, tune_iris_83131_00010, tune_iris_83131_00011, tune_iris_83131_00012, tune_iris_83131_00013, tune_iris_83131_00014, tune_iris_83131_00015, tune_iris_83131_00016, tune_iris_83131_00017])

The program I wrote is as follows:

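# Imports the snippet relies on (not shown in the original post; module paths are assumed).
import inspect
import numpy as np
import pandas as pd
import tensorflow as tf
from ray import tune
from sklearn.model_selection import train_test_split
from tensorflow import keras
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification

# train_webtext, train_gen, valid_webtext, valid_gen, labels1 and labels2 are loaded earlier (not shown).
data = pd.concat([train_webtext,train_gen,valid_webtext,valid_gen])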

sentences=data['text']
labels=labels1+labels2
len(sentences),len(labels)


# Use the tokenizer that matches the "distilbert-base-uncased" model loaded below,
# and keep it in its own variable so the DistilBertTokenizer class is not shadowed.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


input_ids=[]
attention_masks=[]

# Encode every sentence to a fixed length of 64 tokens.
for sent in sentences:
    bert_inp=tokenizer.encode_plus(sent,add_special_tokens = True,max_length =64,pad_to_max_length = True,return_attention_mask = True)
    input_ids.append(bert_inp['input_ids'])
    attention_masks.append(bert_inp['attention_mask'])
    

input_ids=np.asarray(input_ids)
attention_masks=np.array(attention_masks)
labels=np.array(labels)


class TuneReporterCallback(keras.callbacks.Callback):
    """Tune Callback for Keras.
    
    The callback is invoked every epoch.
    """

    def __init__(self):
        super(TuneReporterCallback, self).__init__()
        self.iteration = 0

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes the epoch index and a dict of metrics for that epoch.
        logs = logs or {}
        self.iteration += 1
        tune.report(keras_info=logs, mean_accuracy=logs.get("accuracy"), mean_loss=logs.get("loss"))


def tune_gpt(config):
  train_inp,val_inp,train_label,val_label,train_mask,val_mask=train_test_split(input_ids,labels,attention_masks,test_size=0.6666666666666666)
  DistilBert_model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased',num_labels=2)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
  optimizer = tf.keras.optimizers.Adam(learning_rate=config["learning_rate"],epsilon=1e-08)
  DistilBert_model.compile(loss=loss,optimizer=optimizer,metrics=[metric])
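  # ModelCheckpoint is a single callback: it must not be wrapped in its own list,
  # otherwise `callbacks` contains a nested list instead of a callback object.
  # save_weights_only=True because HDF5 cannot store a full subclassed Keras model.
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( "DistilBert_model.h5",monitor='val_loss',mode='min',save_best_only=True,save_weights_only=True)
  callbacks = [checkpoint_callback, TuneReporterCallback()]
  history=DistilBert_model.fit([train_inp,train_mask],train_label,batch_size=config["batch_size"],epochs=config["epochs"],validation_data=([val_inp,val_mask],val_label),callbacks=callbacks)

# Tune calls the trainable with a single `config` argument.
assert len(inspect.getfullargspec(tune_gpt).args) == 1, "The `tune_gpt` function needs to take in the arg `config`."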



hyperparameter_space  ={
       "batch_size": tune.grid_search([16, 32]),
       "learning_rate": tune.grid_search([2e-5, 3e-5, 5e-5]),
       "epochs": tune.grid_search([2, 3, 4])
    }


analysis = tune.run(
    tune_gpt, 
    verbose=1, 
    config=hyperparameter_space,
   )
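
As a sanity check on the error message, the grid above can be expanded by hand: 2 batch sizes × 3 learning rates × 3 epoch counts gives 18 combinations, which matches the 18 trial names (`..._00000` to `..._00017`) listed in the `TuneError`, so every trial failed. The sketch below uses only the standard library, with plain lists in place of `tune.grid_search`. `expand_grid` is a hypothetical helper, not part of Ray Tune, and the sketch makes no claim about the order in which Tune itself enumerates the combinations.

```python
from itertools import product

# Plain-Python copy of the search space (lists instead of tune.grid_search objects).
grid = {
    "batch_size": [16, 32],
    "learning_rate": [2e-5, 3e-5, 5e-5],
    "epochs": [2, 3, 4],
}


def expand_grid(grid):
    """Return one config dict per combination of the listed values."""
    keys = list(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


configs = expand_grid(grid)
print(len(configs))  # 18
print(configs[0])    # {'batch_size': 16, 'learning_rate': 2e-05, 'epochs': 2}
```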

1 Answer

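Your code appears to contain errors, but because of the verbosity setting, no detailed error message is shown.

Please change the verbose option from

    verbose=1

to

    verbose=3

to see the detailed error.

(Verbosity mode: 0 = silent, 1 = only status updates, 2 = status and brief trial results, 3 = status and detailed trial results. Defaults to 3.)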
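
Besides raising the verbosity, another way to see the underlying exception is to call the trainable once by hand with a single configuration, outside `tune.run`, so that the original traceback is not hidden behind `TuneError`. The sketch below is a minimal illustration under stated assumptions: `run_once` and `failing_trainable` are hypothetical names, and the stand-in trainable fails on purpose so there is something to show. Whether `tune.report` can be called outside a Tune run depends on the Ray version, so the reporting callback may need to be left out for such a dry run.

```python
import traceback


def run_once(trainable, config):
    """Call a function-style trainable directly.

    Returns None if it finishes, or the formatted traceback string if it raises.
    """
    try:
        trainable(config)
    except Exception:
        return traceback.format_exc()
    return None


def failing_trainable(config):
    # Hypothetical stand-in for the real training function; it always fails.
    raise ValueError("simulated failure for batch_size=%d" % config["batch_size"])


error = run_once(failing_trainable, {"batch_size": 16, "learning_rate": 2e-5, "epochs": 2})
print(error if error else "trial finished without raising")
```

With the real code, `tune_gpt` would be passed in place of `failing_trainable`, using one of the grid's configurations.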

answered 2022-02-06T09:05:13.867