Spaces:
Runtime error
Runtime error
File size: 6,425 Bytes
d27fe32 208053f 15f5208 d27fe32 208053f d27fe32 208053f 15f5208 208053f d27fe32 208053f 15f5208 208053f 15f5208 208053f 15f5208 d27fe32 15f5208 d27fe32 15f5208 d27fe32 15f5208 d27fe32 15f5208 208053f d27fe32 208053f 15f5208 208053f 15f5208 d27fe32 208053f 15f5208 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
from typing import Dict, List, Union
import numpy as np
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
# Language name (as used throughout this module) -> code passed to
# EasyGoogleTranslate and used as the ``wikiann`` config name in ``load_dataset``.
# NOTE: the identifier is misspelled ("LANGAUGE") but is kept as-is because other
# code refers to it by this name.
# NOTE(review): a few entries use three-letter codes (e.g. "bam", "ewe", "yor");
# confirm they are accepted by every consumer (translator and dataset loader).
LANGAUGE_TO_PREFIX = dict(
    chinese_simplified="zh-CN",
    arabic="ar",
    hindi="hi",
    indonesian="id",
    amharic="am",
    bengali="bn",
    burmese="my",
    uzbek="uz",
    nepali="ne",
    japanese="ja",
    spanish="es",
    turkish="tr",
    persian="fa",
    azerbaijani="az",
    korean="ko",
    hebrew="he",
    telugu="te",
    german="de",
    greek="el",
    tamil="ta",
    assamese="as",
    russian="ru",
    romanian="ro",
    malayalam="ml",
    swahili="sw",
    bulgarian="bg",
    thai="th",
    urdu="ur",
    polish="pl",
    dutch="nl",
    danish="da",
    norwegian="no",
    finnish="fi",
    hungarian="hu",
    czech="cs",
    ukrainian="uk",
    bambara="bam",
    ewe="ewe",
    fon="fon",
    hausa="hau",
    igbo="ibo",
    kinyarwanda="kin",
    chichewa="nya",
    twi="twi",
    yoruba="yor",
    slovak="sk",
    serbian="sr",
    swedish="sv",
    vietnamese="vi",
    italian="it",
    portuguese="pt",
    chinese="zh",
    english="en",
    french="fr",
)
def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Machine-translate an English instruction into ``target_language``.

    Args:
        basic_instruction: Instruction text written in English.
        target_language: Language name; must be a key of ``LANGAUGE_TO_PREFIX``.

    Returns:
        Whatever ``EasyGoogleTranslate.translate`` returns for the text.

    Raises:
        KeyError: If ``target_language`` is not in ``LANGAUGE_TO_PREFIX``.
    """
    target_code = LANGAUGE_TO_PREFIX[target_language]
    return EasyGoogleTranslate(
        source_language="en", target_language=target_code, timeout=10
    ).translate(basic_instruction)
def create_instruction(lang: str, instruction_language: str, expected_output: str):
    """Return the NER task instruction in ``instruction_language``.

    Args:
        lang: Language name of the data. Used as the translation target only
            when ``instruction_language`` is empty.
        instruction_language: Language name the instruction should be written
            in (``construct_prompt`` passes ``config["prefix"]``); a key of
            ``LANGAUGE_TO_PREFIX``.
        expected_output: Language name the model should write entities in;
            interpolated verbatim into the instruction.

    Returns:
        The English instruction when the target language is ``"english"``,
        otherwise its machine translation via ``_translate_instruction``.
    """
    basic_instruction = (
        "You are an NLP assistant whose\n"
        "purpose is to perform Named Entity Recognition\n"
        "(NER). You will need to give each entity a tag, from the following:\n"
        "PER means a person, ORG means organization.\n"
        "LOC means a location entity.\n"
        "The output should be a list of tuples of the format:\n"
        "['Tag: Entity', 'Tag: Entity'] for each entity in the sentence.\n"
        f"The entities should be in {expected_output} language"
    )
    # ``instruction_language`` names the target language; it is not the
    # instruction text, so the English case must return ``basic_instruction``.
    target_language = instruction_language or lang
    if target_language == "english":
        return basic_instruction
    return _translate_instruction(basic_instruction, target_language=target_language)
def load_wikiann_dataset(lang, split, limit):
    """Load the first ``limit`` rows of one split of the WikiANN dataset.

    Args:
        lang: Language name; must be a key of ``LANGAUGE_TO_PREFIX``. Its code
            is used as the ``wikiann`` config name.
        split: Name of the split to return (e.g. ``"test"``).
        limit: Maximum number of rows to keep. If the split is shorter, the
            whole split is returned.

    Returns:
        The split, restricted to at most ``limit`` leading rows.

    Raises:
        KeyError: If ``lang`` is not in ``LANGAUGE_TO_PREFIX``.
    """
    dataset = load_dataset("wikiann", LANGAUGE_TO_PREFIX[lang])[split]
    # Clamp so a split with fewer than ``limit`` rows does not make ``select``
    # fail on out-of-range indices.
    return dataset.select(range(min(limit, len(dataset))))
def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
):
    """Machine-translate the ``tokens`` and ``ner_tags`` fields of one example.

    Args:
        example: Mapping with ``"tokens"`` and ``"ner_tags"`` entries.
            ``tokens`` may be an already-joined sentence or a list of tokens
            (raw WikiANN rows hold a list).
        src_language: Language name of ``example``; a key of
            ``LANGAUGE_TO_PREFIX``.
        target_language: Language name to translate into; a key of
            ``LANGAUGE_TO_PREFIX``.

    Returns:
        A new dict whose ``"tokens"`` and ``"ner_tags"`` are the translator's
        output for those fields.

    Raises:
        KeyError: If either language is not in ``LANGAUGE_TO_PREFIX``.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGAUGE_TO_PREFIX[src_language],
        target_language=LANGAUGE_TO_PREFIX[target_language],
        timeout=30,
    )
    tokens = example["tokens"]
    # Send the sentence text, not the Python repr of a token list
    # (str(["a", "b"]) == "['a', 'b']" would hand brackets and quotes to the translator).
    if isinstance(tokens, (list, tuple)):
        sentence = " ".join(map(str, tokens))
    else:
        sentence = str(tokens)
    return {
        "tokens": translator.translate(sentence),
        # ner_tags is stringified as-is (for few-shot examples it holds the
        # dataset's "spans" entries).
        "ner_tags": translator.translate(str(example["ner_tags"])),
    }
def choose_few_shot_examples(
    train_dataset: "Dataset",
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Select few-shot examples and render them in the requested languages.

    Args:
        train_dataset: Indexable pool of rows; each row must provide
            ``"tokens"`` (a sequence of token strings) and ``"spans"``.
        few_shot_size: Number of rows to draw.
        context: One language name per example slot. Slot ``i`` uses the
            ``i``-th drawn row, translated from ``lang`` into ``context[i]``
            unless the two languages are equal.
        selection_criteria: ``"first_k"`` (the first ``few_shot_size`` rows)
            or ``"random"`` (uniform draw *with* replacement, so the same row
            can be picked more than once).
        lang: Language name of ``train_dataset``.

    Returns:
        One ``{"tokens": <sentence>, "ner_tags": <spans>}`` dict per entry
        of ``context``.

    Raises:
        ValueError: If ``selection_criteria`` is not recognised.
        IndexError: If ``context`` is longer than the number of drawn rows.
    """
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )
    else:
        # Without a valid criterion there are no rows to index below; fail
        # fast with a clear message instead of an unrelated IndexError.
        raise ValueError(
            f"Unknown selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    ic_examples = [
        {"tokens": " ".join(row["tokens"]), "ner_tags": row["spans"]}
        for row in (train_dataset[idx] for idx in example_idxs)
    ]

    selected_examples = []
    for idx, ic_language in enumerate(context):
        example = ic_examples[idx]
        if ic_language != lang:
            example = _translate_example(
                example=example, src_language=lang, target_language=ic_language
            )
        selected_examples.append(example)
    return selected_examples
def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    dataset: str,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
):
    """Build the NER prompt for one test example.

    Args:
        instruction: Instruction text. If empty, one is generated by
            ``create_instruction`` from ``config["prefix"]`` and
            ``config["output"]``.
        test_example: Example to label; its ``"tokens"`` entry (a sentence or
            a token list) is placed in the prompt.
        zero_shot: If True, no in-context examples are added.
        dataset: Not used by this function; kept for interface compatibility.
        num_examples: Number of in-context examples (few-shot mode only).
        lang: Language name of the data; a key of ``LANGAUGE_TO_PREFIX``.
        config: Language settings read under the keys ``"prefix"``,
            ``"output"``, ``"context"`` and ``"input"``.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If the WikiANN test split for ``lang`` cannot be loaded.
    """
    if not instruction:
        instruction = create_instruction(lang, config["prefix"], config["output"])

    # Loaded even in zero-shot mode, where it only serves to check that
    # ``lang`` is available in WikiANN.
    try:
        test_data = load_wikiann_dataset(lang=lang, split="test", limit=500)
    except Exception as e:
        raise KeyError(
            f"{lang} is not supported in 'wikiAnn' dataset, choose supported language in few-shot"
        ) from e

    if zero_shot:
        zero_shot_template = f"{instruction}" + "\n Sentence: {text} "
        prompt = PromptTemplate(input_variables=["text"], template=zero_shot_template)
    else:
        example_prompt = PromptTemplate(
            input_variables=["tokens", "ner_tags"],
            template="Sentence: {tokens}\nNer Tags: {ner_tags}",
        )
        # NOTE(review): demonstrations are sampled from the same "test" split;
        # if test_example also comes from it, it may appear as its own
        # demonstration — confirm this is intended.
        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="<Text>: {text}",
            input_variables=["text"],
        )

    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )

    text = test_example["tokens"]
    # Raw rows carry a token list; render it as a sentence (as the few-shot
    # examples are) rather than as its Python list repr.
    if isinstance(text, (list, tuple)):
        text = " ".join(map(str, text))
    return prompt.format(text=text)
|