# nlp_tools/toolbox/allennlp/data/dataset_readers/text_classification_json.py
import json
import logging
from typing import Dict, List, Optional, Union

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField, Field, ListField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, SpacyTokenizer
from allennlp.data.tokenizers.sentence_splitter import SpacySentenceSplitter
logger = logging.getLogger(__name__)


@DatasetReader.register("text_classification_json_utf8")
class TextClassificationJsonReader(DatasetReader):
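    """
    Reads one JSON object per line and creates instances for text
    classification. Each line is expected to hold the text under ``text_key``
    and, optionally, the label under ``label_key``, e.g.
    ``{"text": "the movie was great", "label": "pos"}``. This mirrors
    AllenNLP's built-in ``text_classification_json`` reader, but always opens
    the file with UTF-8 encoding (hence the ``_utf8`` registration name).
    """
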
    def __init__(
        self,
        token_indexers: Optional[Dict[str, TokenIndexer]] = None,
        tokenizer: Optional[Tokenizer] = None,
        segment_sentences: bool = False,
        max_sequence_length: Optional[int] = None,
        skip_label_indexing: bool = False,
        text_key: str = "text",
        label_key: str = "label",
        **kwargs,
    ) -> None:
        super().__init__(
            manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs
        )
        self._tokenizer = tokenizer or SpacyTokenizer()
        self._segment_sentences = segment_sentences
        self._max_sequence_length = max_sequence_length
        self._skip_label_indexing = skip_label_indexing
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self._text_key = text_key
        self._label_key = label_key
        if self._segment_sentences:
            self._sentence_segmenter = SpacySentenceSplitter()

    def _read(self, file_path):
        with open(cached_path(file_path), "r", encoding="utf-8") as data_file:
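            # ``shard_iterable`` hands each worker (distributed rank or
            # data-loader process) its own slice of the lines, matching the
            # ``manual_*_sharding`` flags passed to ``super().__init__``.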
            for line in self.shard_iterable(data_file.readlines()):
                if not line.strip():
                    # skip blank lines; ``readlines()`` keeps the trailing
                    # newline, so a bare ``if not line`` would not catch them
                    continue
                items = json.loads(line)
                text = items[self._text_key]
                label = items.get(self._label_key)
                if label is not None:
                    if self._skip_label_indexing:
                        try:
                            label = int(label)
                        except ValueError:
                            raise ValueError(
                                "Labels must be integers if skip_label_indexing is True."
                            )
                    else:
                        label = str(label)
                yield self.text_to_instance(text=text, label=label)

    def _truncate(self, tokens: List[Token]) -> List[Token]:
        """Truncate a token list to at most ``max_sequence_length`` tokens."""
        if len(tokens) > self._max_sequence_length:
            tokens = tokens[: self._max_sequence_length]
        return tokens

    def text_to_instance(  # type: ignore
        self, text: str, label: Optional[Union[str, int]] = None
    ) -> Instance:
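        """
        Build an ``Instance`` whose ``tokens`` field is a ``ListField`` of
        per-sentence ``TextField``s when ``segment_sentences`` is on, and a
        single ``TextField`` otherwise; a ``label`` field is added when a
        label is given.
        """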
        fields: Dict[str, Field] = {}
        if self._segment_sentences:
            sentences: List[Field] = []
            sentence_splits = self._sentence_segmenter.split_sentences(text)
            for sentence in sentence_splits:
                word_tokens = self._tokenizer.tokenize(sentence)
                if self._max_sequence_length is not None:
                    word_tokens = self._truncate(word_tokens)
                sentences.append(TextField(word_tokens))
            fields["tokens"] = ListField(sentences)
        else:
            tokens = self._tokenizer.tokenize(text)
            if self._max_sequence_length is not None:
                tokens = self._truncate(tokens)
            fields["tokens"] = TextField(tokens)
        if label is not None:
            fields["label"] = LabelField(label, skip_indexing=self._skip_label_indexing)
        return Instance(fields)

    def apply_token_indexers(self, instance: Instance) -> None:
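        # Indexers are attached here, after instance creation, so instances
        # stay cheap to pickle when sent across data-loader worker processes.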
        if self._segment_sentences:
            for text_field in instance.fields["tokens"]:  # type: ignore
                text_field._token_indexers = self._token_indexers
        else:
            instance.fields["tokens"]._token_indexers = self._token_indexers  # type: ignore
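

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module. It assumes the
    # default SpacyTokenizer (i.e. a spaCy English model is installed) and a
    # hypothetical JSON-lines file ``data.jsonl`` whose lines look like
    # {"text": "the movie was great", "label": "pos"}.
    reader = TextClassificationJsonReader(max_sequence_length=512)
    for instance in reader.read("data.jsonl"):
        reader.apply_token_indexers(instance)
        print(instance)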