import functools
import logging
from typing import Dict, List, Union

import numpy as np
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate

logger = logging.getLogger(__name__)

# Maps the language names used throughout this module to the language codes
# handed to Google Translate and used as ``wikiann`` configuration names.
# NOTE: the misspelled constant name is kept because other modules may import it.
LANGAUGE_TO_PREFIX = {
    "chinese_simplified": "zh-CN",
    "arabic": "ar",
    "hindi": "hi",
    "indonesian": "id",
    "amharic": "am",
    "bengali": "bn",
    "burmese": "my",
    "uzbek": "uz",
    "nepali": "ne",
    "japanese": "ja",
    "spanish": "es",
    "turkish": "tr",
    "persian": "fa",
    "azerbaijani": "az",
    "korean": "ko",
    "hebrew": "he",
    "telugu": "te",
    "german": "de",
    "greek": "el",
    "tamil": "ta",
    "assamese": "as",
    "russian": "ru",
    "romanian": "ro",
    "malayalam": "ml",
    "swahili": "sw",
    "bulgarian": "bg",
    "thai": "th",
    "urdu": "ur",
    "polish": "pl",
    "dutch": "nl",
    "danish": "da",
    "norwegian": "no",
    "finnish": "fi",
    "hungarian": "hu",
    "czech": "cs",
    "ukrainian": "uk",
    "bambara": "bam",
    "ewe": "ewe",
    "fon": "fon",
    "hausa": "hau",
    "igbo": "ibo",
    "kinyarwanda": "kin",
    "chichewa": "nya",
    "twi": "twi",
    "yoruba": "yor",
    "slovak": "sk",
    "serbian": "sr",
    "swedish": "sv",
    "vietnamese": "vi",
    "italian": "it",
    "portuguese": "pt",
    "chinese": "zh",
    "english": "en",
    "french": "fr",
}


def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Translate an English instruction into ``target_language``.

    Args:
        basic_instruction: Instruction text written in English.
        target_language: Key of ``LANGAUGE_TO_PREFIX`` naming the target language.

    Returns:
        The translated instruction text.

    Raises:
        KeyError: If ``target_language`` is not in ``LANGAUGE_TO_PREFIX``.
    """
    translator = EasyGoogleTranslate(
        source_language="en",
        target_language=LANGAUGE_TO_PREFIX[target_language],
        timeout=10,
    )
    return translator.translate(basic_instruction)


def create_instruction(lang: str, instruction_language: str, expected_output: str) -> str:
    """Build the NER task instruction in ``instruction_language``.

    Args:
        lang: Language of the data being evaluated. Kept for interface
            compatibility; the instruction's language is taken from
            ``instruction_language``.
        instruction_language: Language the instruction should be written in
            (a key of ``LANGAUGE_TO_PREFIX``).
        expected_output: Language name the model is asked to write entities in.

    Returns:
        The instruction text, in English or translated via Google Translate.
    """
    basic_instruction = (
        "You are an NLP assistant whose purpose is to perform Named Entity "
        "Recognition (NER). You will need to give each entity a tag, from the "
        "following: PER means a person, ORG means organization. LOC means a "
        "location entity. The output should be a list of tuples of the format: "
        "['Tag: Entity', 'Tag: Entity'] for each entity in the sentence. \n"
        f"The entities should be in {expected_output} language"
    )
    # The template is authored in English, so only non-English targets need
    # a translation round-trip. (Returning ``instruction_language`` here would
    # yield a bare language name instead of an instruction.)
    if instruction_language == "english":
        return basic_instruction
    return _translate_instruction(basic_instruction, target_language=instruction_language)


@functools.lru_cache(maxsize=32)
def load_wikiann_dataset(lang: str, split: str, limit: int) -> Dataset:
    """Load at most ``limit`` rows of one split of the WikiANN dataset.

    Results are memoised per argument combination because ``construct_prompt``
    requests the same split for every test example it formats.

    Args:
        lang: Key of ``LANGAUGE_TO_PREFIX``; its code is the WikiANN config name.
        split: Dataset split to load (e.g. ``"train"``, ``"test"``).
        limit: Maximum number of rows to keep, taken from the start of the split.

    Returns:
        A ``Dataset`` with ``min(limit, len(split))`` rows.

    Raises:
        KeyError: If ``lang`` is not in ``LANGAUGE_TO_PREFIX`` or ``split`` is missing.
    """
    dataset = load_dataset("wikiann", LANGAUGE_TO_PREFIX[lang])[split]
    # ``select`` rejects out-of-range indices, so clamp for splits shorter
    # than ``limit`` instead of failing for small languages.
    return dataset.select(np.arange(min(limit, len(dataset))))


def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
) -> Dict[str, str]:
    """Translate an example's ``tokens`` and ``ner_tags`` fields.

    Both fields are converted with ``str()`` before translation, so list-valued
    fields are translated as their Python list representation.

    Args:
        example: Mapping with ``"tokens"`` and ``"ner_tags"`` entries.
        src_language: Key of ``LANGAUGE_TO_PREFIX`` for the source language.
        target_language: Key of ``LANGAUGE_TO_PREFIX`` for the target language.

    Returns:
        A new dict with translated ``"tokens"`` and ``"ner_tags"`` strings.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGAUGE_TO_PREFIX[src_language],
        target_language=LANGAUGE_TO_PREFIX[target_language],
        timeout=30,
    )
    return {
        "tokens": translator.translate(str(example["tokens"])),
        "ner_tags": translator.translate(str(example["ner_tags"])),
    }


def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Select few-shot examples from a dataset, translating them as requested.

    Args:
        train_dataset: Dataset to draw examples from; rows need ``"tokens"``
            and ``"spans"`` fields.
        few_shot_size: Number of candidate examples to draw.
        context: One target language per example to return; entry ``i`` is
            the language example ``i`` should be shown in.
        selection_criteria: ``"first_k"`` for the first ``few_shot_size`` rows,
            or ``"random"`` for uniform sampling with replacement.
        lang: Language of ``train_dataset``; examples whose context language
            differs are machine-translated from this language.

    Returns:
        One dict per entry of ``context`` with ``"tokens"`` (space-joined
        sentence) and ``"ner_tags"`` (the row's ``spans``, or their translation).

    Raises:
        ValueError: If ``selection_criteria`` is not a supported choice.
        IndexError: If ``context`` is longer than ``few_shot_size``.
    """
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        # Sampling is with replacement, so an example may be drawn twice.
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )
    else:
        raise ValueError(
            f"Unsupported selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    rows = [train_dataset[idx] for idx in example_idxs]
    ic_examples = [
        {"tokens": " ".join(row["tokens"]), "ner_tags": row["spans"]} for row in rows
    ]

    selected_examples = []
    for idx, ic_language in enumerate(context):
        example = ic_examples[idx]
        if ic_language != lang:
            example = _translate_example(
                example=example,
                src_language=lang,
                target_language=ic_language,
            )
        selected_examples.append(example)
    return selected_examples


def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    dataset: str,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
) -> str:
    """Build the NER prompt for a single WikiANN test example.

    Args:
        instruction: Task instruction; when empty, one is generated with
            ``create_instruction``.
        test_example: Row to label; its ``"tokens"`` field fills the prompt.
        zero_shot: If True, no in-context examples are included.
        dataset: Not used by this function; kept for interface compatibility.
        num_examples: Number of in-context examples for few-shot prompts.
        lang: Language of ``test_example`` and of the WikiANN split loaded.
        config: Language settings. Keys read here: ``"prefix"`` (instruction
            language), ``"output"`` (entity language), ``"context"`` (language
            of in-context examples) and ``"input"`` (language the test
            sentence is shown in).

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If the WikiANN test split for ``lang`` cannot be loaded.
    """
    if not instruction:
        instruction = create_instruction(lang, config["prefix"], config["output"])
    example_prompt = PromptTemplate(
        input_variables=["tokens", "ner_tags"],
        template="Sentence: {tokens}\nNer Tags: {ner_tags}",
    )
    # ``{{text}}`` escapes the braces so the template keeps a ``{text}`` slot.
    zero_shot_template = f"{instruction}\n Sentence: {{text}} "

    try:
        test_data = load_wikiann_dataset(lang=lang, split="test", limit=500)
    except Exception as e:
        raise KeyError(
            f"{lang} is not supported in 'wikiAnn' dataset, choose supported language in few-shot"
        ) from e

    if zero_shot:
        prompt = PromptTemplate(input_variables=["text"], template=zero_shot_template)
    else:
        # NOTE(review): demonstrations are drawn from the *test* split, so a
        # random pick may coincide with ``test_example`` itself — confirm this
        # leakage is acceptable or switch to the train split.
        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )
        # NOTE(review): the suffix lacks the "Sentence" label used by
        # ``example_prompt`` — confirm the format is intentional.
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix=": {text}",
            input_variables=["text"],
        )

    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
    logger.debug("Test example: %s", test_example)
    return prompt.format(text=test_example["tokens"])