0

我正在使用AllenNLP来训练分层注意力网络模型。我的训练数据集包含一个JSON 对象列表(例如,列表中的每个对象都是一个带有键的JSON 对象:= ["text", "label"]。与文本键关联的值是一个列表列表,例如:

[{"text":[["i", "feel", "sad"], ["not", "sure", "i", "guess", "the", "weather"]], "label":0} ... {"text":[[str]], "label":int}] 

我的 DatasetReader 类看起来像:

@DatasetReader.register("my_reader")
class TranscriptDataReader(DatasetReader):
    def __init__(self,
                 token_indexers: Optional[Dict[str, TokenIndexer]] = None,
                 lazy: bool = True) -> None:
        super().__init__(lazy)
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}

    def _read(self, file_path: str) -> Iterator[Instance]:
        with open(file_path, 'r') as f:
            data = json.loads(f.read())
            for _,data_json in enumerate(data):
                sent_list = []
                for segment in data_json["text"]:
                    sent_list.append(self.get_text_field(segment))
                yield self.create_instance(sent_list, str(data_json["label"]))

    def get_text_field(self, segment):
        return TextField([Token(token.lower()) for token in segment],self._token_indexers)


    def create_instance(self, sent_list, label):
        label_field = LabelField(label, skip_indexing=False)
        fields = {'tokens': ListField(sent_list), 'label': label_field}
        return Instance(fields)

在我的配置文件中,我有:

{
  dataset_reader: {
    type: 'my_reader',
  },

  train_data_path: 'data/train.json',
  validation_data_path: 'data/dev.json',

 data_loader: {
    batch_sampler: {
      type: 'bucket',
      batch_size: 10
    }
 },

我已经尝试(或者)将lazy数据集读取器的参数设置为Trueand False

  • 当设置为True时,模型能够进行训练,但是,我观察到只有一列火车和一个开发实例实际被加载,而我的数据集包含 ~100。
  • 当设置为 时False,我已将yield行修改_readreturn; 但是,这会导致基本词汇类中的类型错误。我也尝试yield在设置为时保持原样False;在这种情况下,根本没有加载任何实例,并且由于实例集是空的,因此词汇表不会被实例化,并且嵌入类会引发错误。

将不胜感激指针和/或调试提示。

4

2 回答 2

1

如果您使用allennlp>=v2.0.0,则不推荐使用构造函数lazy中的参数。DatasetReader因此,您super().__init__(lazy)将被解释为新的构造函数参数max_instances,即max_instances=True等效于max_instances=1.

于 2021-02-16T06:01:39.140 回答
0

您能否打印并告诉我们在读取 json 文件后加载了多少实例(为清楚起见,在下面添加了打印命令)

def _read(self, file_path: str) -> Iterator[Instance]:
        with open(file_path, 'r') as f:
            data = json.loads(f.read())
            print(len(data))
            for _,data_json in enumerate(data):
               sent_list = []
                for segment in data_json["text"]:
                    sent_list.append(self.get_text_field(segment))
                yield self.create_instance(sent_list, str(data_json["label"]))
于 2021-02-09T00:04:19.857 回答