Pradeep Kumar committed on
Commit
9dcd3ec
·
verified ·
1 Parent(s): ee402bd

Delete tagging_dataloader.py

Browse files
Files changed (1) hide show
  1. tagging_dataloader.py +0 -90
tagging_dataloader.py DELETED
@@ -1,90 +0,0 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Loads dataset for the tagging (e.g., NER/POS) task."""
16
- import dataclasses
17
- from typing import Mapping, Optional
18
-
19
- import tensorflow as tf, tf_keras
20
- from official.common import dataset_fn
21
- from official.core import config_definitions as cfg
22
- from official.core import input_reader
23
- from official.nlp.data import data_loader
24
- from official.nlp.data import data_loader_factory
25
-
26
-
27
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
  """Input-pipeline settings for the tagging task (tasks/tagging).

  Instances of this config are what `TaggingDataLoader` is registered for and
  constructed from.
  """
  # Training-mode flag. Not read in this file; presumably consumed by the
  # shared input reader -- confirm against `input_reader.InputReader`.
  is_training: bool = True
  # Fixed length of each per-token feature (`input_ids`, `input_mask`,
  # `segment_ids`, `label_ids`) when a serialized example is decoded.
  seq_length: int = 128
  # When True, the scalar `sentence_id` and `sub_sentence_id` features are
  # decoded as well and forwarded to the model inputs.
  include_sentence_id: bool = False
  # Storage format name handed to `dataset_fn.pick_dataset_fn`.
  file_type: str = 'tfrecord'
34
-
35
-
36
@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
class TaggingDataLoader(data_loader.DataLoader):
  """Dataset loader for token-tagging tasks such as NER and POS."""

  def __init__(self, params: TaggingDataConfig):
    """Stores the config and caches the fields used while decoding.

    Args:
      params: The `TaggingDataConfig` describing the input data.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._include_sentence_id = params.include_sentence_id

  def _decode(self, record: tf.Tensor):
    """Turns one serialized tf.Example into a dict of int32 tensors.

    Args:
      record: A serialized tf.Example.

    Returns:
      A dict keyed by feature name. Any feature parsed as int64 is cast to
      int32; other dtypes are passed through untouched.
    """
    # Every per-token feature is a fixed-length int64 vector of `seq_length`.
    feature_spec = {
        key: tf.io.FixedLenFeature([self._seq_length], tf.int64)
        for key in ('input_ids', 'input_mask', 'segment_ids', 'label_ids')
    }
    if self._include_sentence_id:
      # The sentence identifiers are scalars, hence the empty shape.
      for key in ('sentence_id', 'sub_sentence_id'):
        feature_spec[key] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(record, feature_spec)

    # tf.Example stores integers as int64 only, whereas the TPU works with
    # int32, so narrow every int64 feature.
    return {
        key: tf.cast(value, tf.int32) if value.dtype == tf.int64 else value
        for key, value in parsed.items()
    }

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Splits a decoded record into model inputs and labels.

    Args:
      record: The feature dict produced by `_decode`.

    Returns:
      A `(features, labels)` tuple, where `features` is a dict keyed by the
      input names the model expects and `labels` is `record['label_ids']`.
    """
    # (model input name, decoded feature name) pairs.
    renames = (
        ('input_word_ids', 'input_ids'),
        ('input_mask', 'input_mask'),
        ('input_type_ids', 'segment_ids'),
    )
    features = {
        model_key: record[record_key] for model_key, record_key in renames
    }
    if self._include_sentence_id:
      for key in ('sentence_id', 'sub_sentence_id'):
        features[key] = record[key]

    return (features, record['label_ids'])

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the input pipeline for this config.

    Args:
      input_context: Optional distribution context, forwarded unchanged to
        `InputReader.read`.

    Returns:
      Whatever `InputReader.read` returns (a `tf.data.Dataset`, going by the
      original documentation of this method).
    """
    # Pick the dataset constructor that matches the configured file type.
    make_dataset = dataset_fn.pick_dataset_fn(self._params.file_type)
    reader = input_reader.InputReader(
        params=self._params,
        dataset_fn=make_dataset,
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)