Pradeep Kumar committed on
Commit
c264923
·
verified ·
1 Parent(s): 717d17b

Delete question_answering_dataloader.py

Browse files
Files changed (1) hide show
  1. question_answering_dataloader.py +0 -115
question_answering_dataloader.py DELETED
@@ -1,115 +0,0 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Loads dataset for the question answering (e.g, SQuAD) task."""
16
- import dataclasses
17
- from typing import Mapping, Optional
18
-
19
- import tensorflow as tf, tf_keras
20
- from official.common import dataset_fn
21
- from official.core import config_definitions as cfg
22
- from official.core import input_reader
23
- from official.nlp.data import data_loader
24
- from official.nlp.data import data_loader_factory
25
-
26
-
27
- @dataclasses.dataclass
28
- class QADataConfig(cfg.DataConfig):
29
- """Data config for question answering task (tasks/question_answering)."""
30
- # For training, `input_path` is expected to be a pre-processed TFRecord file,
31
- # while for evaluation, it is expected to be a raw JSON file (b/173814590).
32
- input_path: str = ''
33
- global_batch_size: int = 48
34
- is_training: bool = True
35
- seq_length: int = 384
36
- # Settings below are question answering specific.
37
- version_2_with_negative: bool = False
38
- # Settings below are only used for eval mode.
39
- input_preprocessed_data_path: str = ''
40
- doc_stride: int = 128
41
- query_length: int = 64
42
- # The path to the vocab file of word piece tokenizer or the
43
- # model of the sentence piece tokenizer.
44
- vocab_file: str = ''
45
- tokenization: str = 'WordPiece' # WordPiece or SentencePiece
46
- do_lower_case: bool = True
47
- xlnet_format: bool = False
48
- file_type: str = 'tfrecord'
49
-
50
-
51
- @data_loader_factory.register_data_loader_cls(QADataConfig)
52
- class QuestionAnsweringDataLoader(data_loader.DataLoader):
53
- """A class to load dataset for sentence prediction (classification) task."""
54
-
55
- def __init__(self, params):
56
- self._params = params
57
- self._seq_length = params.seq_length
58
- self._is_training = params.is_training
59
- self._xlnet_format = params.xlnet_format
60
-
61
- def _decode(self, record: tf.Tensor):
62
- """Decodes a serialized tf.Example."""
63
- name_to_features = {
64
- 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
65
- 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
66
- 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
67
- }
68
- if self._xlnet_format:
69
- name_to_features['class_index'] = tf.io.FixedLenFeature([], tf.int64)
70
- name_to_features['paragraph_mask'] = tf.io.FixedLenFeature(
71
- [self._seq_length], tf.int64)
72
- if self._is_training:
73
- name_to_features['is_impossible'] = tf.io.FixedLenFeature([], tf.int64)
74
-
75
- if self._is_training:
76
- name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
77
- name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
78
- else:
79
- name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
80
- example = tf.io.parse_single_example(record, name_to_features)
81
-
82
- # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
83
- # So cast all int64 to int32.
84
- for name in example:
85
- t = example[name]
86
- if t.dtype == tf.int64:
87
- t = tf.cast(t, tf.int32)
88
- example[name] = t
89
-
90
- return example
91
-
92
- def _parse(self, record: Mapping[str, tf.Tensor]):
93
- """Parses raw tensors into a dict of tensors to be consumed by the model."""
94
- x, y = {}, {}
95
- for name, tensor in record.items():
96
- if name in ('start_positions', 'end_positions', 'is_impossible'):
97
- y[name] = tensor
98
- elif name == 'input_ids':
99
- x['input_word_ids'] = tensor
100
- elif name == 'segment_ids':
101
- x['input_type_ids'] = tensor
102
- else:
103
- x[name] = tensor
104
- if name == 'start_positions' and self._xlnet_format:
105
- x[name] = tensor
106
- return (x, y)
107
-
108
- def load(self, input_context: Optional[tf.distribute.InputContext] = None):
109
- """Returns a tf.dataset.Dataset."""
110
- reader = input_reader.InputReader(
111
- params=self._params,
112
- dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
113
- decoder_fn=self._decode,
114
- parser_fn=self._parse)
115
- return reader.read(input_context)