# selective_pre_translation/generate_prompt.py
# Author: Anonymous | commit 208053f ("add files") | 19.4 kB
import csv
import enum
import json
import logging
import os
import re
import string
import sys
import unicodedata
from typing import Any, Dict, List, NewType, Union
import numpy as np
import openai
import pandas as pd
import requests
import yaml
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from tqdm import tqdm
from yaml.loader import SafeLoader
from selective_pre_translation.tasks import qa, summarization, ner, nli
# from models.model_completion import gpt3x_completion, gemini_completion
class LanguageType(enum.Enum):
    """Word-count based language category returned by ``_get_language_type``."""

    Low = "Low"
    High = "High"
class ModelType(enum.Enum):
    """Kind of model being prompted; ``recommend_config`` compares against these values."""

    English = "English"
    Multilingual = "Multilingual"
def get_entities_gpt3_long(prompt):
    """Send ``prompt`` as one user message to the ``chatgpt`` engine and return the reply text.

    The request uses temperature 0. The returned value is the ``content`` of the
    first choice in the ChatCompletion response.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completion = openai.ChatCompletion.create(
        engine="chatgpt",
        temperature=0,
        messages=chat_messages,
    )
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]
def gpt3x_completion(
    prompt: Union[str, List[Dict[str, str]]],
) -> str:
    """Query the ``gpt35-16k`` chat engine (temperature 0) and return the reply text.

    Args:
        prompt: Either a plain string, sent as a single user message, or a
            ready-made list of chat messages (``{"role": ..., "content": ...}``
            dicts), which is sent as-is.

    Returns:
        The ``content`` of the first choice in the response.

    Note:
        This function does not touch ``OPENAI_API_KEY`` or any other
        credential. Configure credentials before calling it.
    """
    import openai  # kept local so the helper stays self-contained

    if isinstance(prompt, str):
        messages = [{"role": "user", "content": prompt}]
    else:
        # The annotation allows a full chat transcript; pass it through
        # instead of embedding the list object as a message's content.
        messages = list(prompt)

    response = openai.ChatCompletion.create(
        engine="gpt35-16k",
        temperature=0,
        messages=messages,
    )
    return response["choices"][0]["message"]["content"]
def mixtral_completion(prompt, api_key: Union[str, None] = None, timeout: float = 60.0):
    """Query Mixtral-8x7B-Instruct through the Together chat-completions endpoint.

    Args:
        prompt: Prompt text. It is interpolated (``f"{prompt}"``) into a
            single user message.
        api_key: Together API key. When omitted, the ``TOGETHER_API_KEY``
            environment variable is used, falling back to an empty string.
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests`` may block indefinitely on a stalled connection.
            ``requests`` raises its Timeout exception when the limit is hit.

    Returns:
        The reply text (at most 30 tokens, temperature 0) on HTTP 200. On any
        other status, the status code and body are printed and ``None`` is
        returned.
    """
    url = "https://api.together.xyz/v1/chat/completions"
    if api_key is None:
        api_key = os.environ.get("TOGETHER_API_KEY", "")

    payload = {
        "temperature": 0,
        "max_tokens": 30,
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": [{"role": "user", "content": f"{prompt}"}],
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]

    print(f"Error: {response.status_code} - {response.text}")
    return None
# Language name -> XQuAD config suffix; load_qa_dataset loads f"xquad.{code}".
# NOTE(review): bengali, korean, swahili, indonesian, finnish and telugu have
# no XQuAD config (they look copied from the TyDi QA table), so loading them
# fails inside load_dataset. They are kept so existing keys keep resolving.
XQUAD_LANG2CODES = {
    "bengali": "bn",
    "korean": "ko",
    "swahili": "sw",
    "english": "en",
    "indonesian": "id",
    "arabic": "ar",
    "finnish": "fi",
    "telugu": "te",
    "russian": "ru",
    "german": "de",
    "greek": "el",
    "hindi": "hi",
    "vietnamese": "vi",
    "romanian": "ro",
    # Remaining XQuAD languages that were previously unmapped.
    "spanish": "es",
    "thai": "th",
    "turkish": "tr",
    "chinese": "zh",
}
# Language name -> IndicQA config suffix; load_qa_dataset loads f"indicqa.{code}".
# The "indicqa" key is an alias that maps to the same code as "assamese".
INDICQA_LANG2CODES = dict(
    indicqa="as",
    bengali="bn",
    gujarati="gu",
    hindi="hi",
    kannada="kn",
    malayalam="ml",
    marathi="mr",
    odia="or",
    punjabi="pa",
    tamil="ta",
    telugu="te",
    assamese="as",
)
# Every character whose Unicode general category starts with "P"
# (punctuation), plus ASCII string.punctuation, which also includes symbols
# such as "$", "+" and "<". range(sys.maxunicode) stops one short of
# sys.maxunicode itself; that code point (U+10FFFF) is not punctuation.
PUNCT = set(string.punctuation)
PUNCT.update(
    char
    for char in map(chr, range(sys.maxunicode))
    if unicodedata.category(char)[0] == "P"
)
# Language codes grouped by answer segmentation. Judging by the names, the
# first group is split on whitespace and the second needs mixed segmentation
# (presumably as in the MLQA evaluation script; neither list is used in the
# code visible here).
WHITESPACE_LANGS = "en es hi vi de ar".split()
MIXED_SEGMENTATION_LANGS = ["zh"]
# Language name (the prefix of a TyDi QA example id before the first "-")
# -> language code; load_qa_dataset tags each example with this code and
# filters on it.
TYDIQA_LANG2CODES = dict(
    bengali="bn",
    korean="ko",
    swahili="sw",
    english="en",
    indonesian="id",
    arabic="ar",
    finnish="fi",
    telugu="te",
    russian="ru",
    assamese="as",
    persian="fa",
)
# Module logger. logging.getLogger registers it in the logging hierarchy, so
# handlers and levels configured on the root logger apply to it. A directly
# instantiated logging.Logger has no parent and no handlers, so its INFO
# records were silently dropped. The historical name is kept so other code
# calling getLogger("Xlsum_task") gets this same logger.
logger = logging.getLogger("Xlsum_task")
# Lower-case language name -> Google Translate language code, used to build
# EasyGoogleTranslate instances. The former duplicate "hindi" and "telugu"
# entries (identical values) are removed; key order is unchanged.
LANGUAGE_TO_SUFFIX = {
    "chinese_simplified": "zh-CN",
    "french": "fr",
    "portuguese": "pt",
    "english": "en",
    "arabic": "ar",
    "hindi": "hi",
    "indonesian": "id",
    "amharic": "am",
    "bengali": "bn",
    "telugu": "te",
    "burmese": "my",
    "german": "de",
    "greek": "el",
    "tamil": "ta",
    "assamese": "as",
    "vietnamese": "vi",
    "russian": "ru",
    "romanian": "ro",
    "malayalam": "ml",
    "persian": "fa",
    # Languages listed in this module's QA dataset tables that had no entry,
    # which made translating them raise KeyError.
    "korean": "ko",
    "swahili": "sw",
    "finnish": "fi",
    "gujarati": "gu",
    "kannada": "kn",
    "marathi": "mr",
    "odia": "or",
    "punjabi": "pa",
}
# Distinct type for the parameter mapping returned by read_parameters (the
# parsed YAML document, expected to be a str-keyed dict).
PARAMS = NewType("PARAMS", Dict[str, Any])
def read_parameters(args_path) -> PARAMS:
    """Load a YAML parameter file using PyYAML's ``SafeLoader``.

    The file is opened as UTF-8 explicitly. The platform default encoding
    (e.g. cp1252 on Windows) would mis-decode non-ASCII content.

    Args:
        args_path: Path to the YAML file.

    Returns:
        The parsed YAML document. It is expected to be a mapping of parameter
        names to values; this is not validated here.
    """
    with open(args_path, encoding="utf-8") as f:
        return yaml.load(f, Loader=SafeLoader)
def load_qa_dataset(dataset_name, lang, split, translate_test=False, limit=5):
    """Load an extractive-QA dataset and keep only its first ``limit`` rows.

    Args:
        dataset_name: One of ``"indicqa"``, ``"xquad"``, ``"tydiqa"``, ``"mlqa"``.
        lang: Language key. For indicqa/xquad it is looked up in
            INDICQA_LANG2CODES / XQUAD_LANG2CODES. For tydiqa it is compared
            with the code derived from each example id via TYDIQA_LANG2CODES.
            For mlqa it is used verbatim in the config name.
        split: Split to load. For indicqa/xquad, ``"train"`` loads SQuAD v2 /
            SQuAD instead; xquad otherwise always uses ``"validation"``. For
            mlqa, ``"train"`` is switched to ``"validation"``.
        translate_test: mlqa only; use the ``mlqa-translate-test`` config.
        limit: Maximum number of leading rows to keep. ``None`` keeps all
            rows. A limit larger than the dataset is clamped rather than
            raising ``IndexError`` (e.g. after the tydiqa language filter).

    Returns:
        The loaded dataset restricted to at most ``limit`` rows.

    Raises:
        NotImplementedError: For an unsupported ``dataset_name``.
    """
    if dataset_name == "indicqa":
        if split != "train":
            dataset = load_dataset(
                "ai4bharat/IndicQA", f"indicqa.{INDICQA_LANG2CODES[lang]}"
            )[split]
        else:
            dataset = load_dataset("squad_v2")[split]
    elif dataset_name == "xquad":
        if split != "train":
            dataset = load_dataset("xquad", f"xquad.{XQUAD_LANG2CODES[lang]}")[
                "validation"
            ]
        else:
            dataset = load_dataset("squad")[split]
    elif dataset_name == "tydiqa":
        dataset = load_dataset("tydiqa", "secondary_task")[split]
        # Example ids start with the language name, e.g. "english-<...>".
        dataset = dataset.map(
            lambda example: {"lang": TYDIQA_LANG2CODES[example["id"].split("-")[0]]}
        )
        dataset = dataset.filter(lambda example: example["lang"] == lang)
    elif dataset_name == "mlqa":
        if split == "train":
            print("No Training Data for MLQA, switching to validation!")
            split = "validation"
        if translate_test:
            config_name = f"mlqa-translate-test.{lang}"
        else:
            config_name = f"mlqa.{lang}.{lang}"
        dataset = load_dataset("mlqa", config_name)[split]
    else:
        raise NotImplementedError()

    if limit is None:
        return dataset
    # Clamp so asking for more rows than exist does not make select() fail.
    return dataset.select(range(min(limit, len(dataset))))
def construct_prompt(
    instruction: str,
    test_example: dict,
    ic_examples: List[dict],
    zero_shot: bool,
    lang: str,
    config: Dict[Any, Any],
):
    """Build the QA prompt for ``test_example`` and return it with the gold label.

    Args:
        instruction: Instruction text placed before the examples / question.
        test_example: Dict with ``question``, ``context`` and ``answers``.
        ic_examples: Few-shot examples with ``context``, ``question`` and
            ``answers`` keys. Ignored when ``zero_shot`` is true.
        zero_shot: Build a prompt without in-context examples.
        lang: Language of ``test_example``.
        config: Must contain ``"input"``. When it differs from ``lang``, the
            question and context are machine-translated into that language
            before formatting.

    Returns:
        ``(prompt, label)``, where ``label`` is ``test_example["answers"]``
        taken before any translation.
    """

    def _escape_braces(value):
        # LangChain f-string templates treat "{" / "}" as placeholders. The
        # instruction is template text, and FewShotPromptTemplate joins the
        # formatted examples into the template before its final format()
        # call, so literal braces in either must be doubled to survive.
        if isinstance(value, str):
            return value.replace("{", "{{").replace("}", "}}")
        if isinstance(value, list):
            return [_escape_braces(item) for item in value]
        return value

    safe_instruction = _escape_braces(instruction)
    example_prompt = PromptTemplate(
        input_variables=["context", "question", "answers"],
        template="Context: {context}\nQuestion: {question}\n" "Answers: {answers}",
    )
    if zero_shot:
        prompt = PromptTemplate(
            input_variables=["question", "context"],
            template=safe_instruction + "\n<Context>: {context} \n<Question>: {question} ",
        )
    else:
        prompt = FewShotPromptTemplate(
            # Escaped copies; the caller's example dicts are not mutated.
            examples=[
                {key: _escape_braces(val) for key, val in example.items()}
                for example in ic_examples
            ],
            prefix=safe_instruction,
            example_prompt=example_prompt,
            suffix="<Context>: {context} \n<Question>: {question} \nAnswers: ?",
            input_variables=["question", "context"],
        )

    # The label is captured before translation so it stays in the source language.
    label = test_example["answers"]
    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
    return (
        prompt.format(
            question=test_example["question"], context=test_example["context"]
        ),
        label,
    )
def dump_metrics(
    lang: str, config: Dict[str, str], f1: float, em: float, metric_logger_path: str
):
    """Append one row of QA metrics to a CSV file, writing a header first if needed.

    Columns: Language, Prefix, Input, Context, Output, F1, Em. "Context" is
    only the first element of ``config["context"]``.

    The header is written whenever the file is empty, whether it is newly
    created or an existing empty file. The file is written as UTF-8 so
    non-ASCII values are stored portably.
    """
    with open(metric_logger_path, "a", newline="", encoding="utf-8") as f:
        csvwriter = csv.writer(f, delimiter=",")
        # Append mode starts at end-of-file, so position 0 means "empty file".
        if f.tell() == 0:
            csvwriter.writerow(["Language", "Prefix", "Input", "Context", "Output", "F1", "Em"])
        csvwriter.writerow(
            [
                lang,
                config["prefix"],
                config["input"],
                config["context"][0],
                config["output"],
                f1,
                em,
            ]
        )
def dump_predictions(idx, response, label, response_logger_file):
    """Append one prediction as a JSON line to ``response_logger_file``.

    Each line is ``{"q_idx": idx, "prediction": response, "label": label}``.
    ``ensure_ascii=False`` keeps non-ASCII text unescaped, so the file is
    opened as UTF-8. With the platform default encoding such text can raise
    ``UnicodeEncodeError``.
    """
    record = {"q_idx": idx, "prediction": response, "label": label}
    with open(response_logger_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Translate an English instruction into ``target_language`` via EasyGoogleTranslate.

    Args:
        basic_instruction: Instruction text in English.
        target_language: Language name, matched case-insensitively against
            LANGUAGE_TO_SUFFIX (the same lookup ``_translate_example`` uses).

    Raises:
        KeyError: If the language has no entry in LANGUAGE_TO_SUFFIX.
    """
    translator = EasyGoogleTranslate(
        source_language="en",
        target_language=LANGUAGE_TO_SUFFIX[str(target_language).lower()],
        timeout=50,
    )
    return translator.translate(basic_instruction)
def _translate_prediction_to_output_language(
    prediction: str, prediction_language: str, output_language: str
) -> str:
    """Translate a model prediction from ``prediction_language`` into ``output_language``.

    Both language names are matched case-insensitively against
    LANGUAGE_TO_SUFFIX (the same lookup ``_translate_example`` uses).

    Raises:
        KeyError: If either language has no entry in LANGUAGE_TO_SUFFIX.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[str(prediction_language).lower()],
        target_language=LANGUAGE_TO_SUFFIX[str(output_language).lower()],
        timeout=10,
    )
    return translator.translate(prediction)
def create_instruction(lang: str, expected_output: str):
    """Return the QA answering instruction, translated when ``lang`` is not English.

    Args:
        lang: Language of the instruction. ``"english"`` returns the text
            unchanged; any other value is passed to ``_translate_instruction``.
        expected_output: Name of the language the answer must be given in.
            It is interpolated into rule 4.
    """
    rules = [
        "1. The answer should include only words from the given context\n",
        "2. The answer must include up to 5 words\n",
        "3. The answer Should be the shortest as possible\n",
        f"4. The answer must be in {expected_output} only!, not another language!!!",
    ]
    basic_instruction = (
        "Answer to the <Question> below, based only to the given <Context>, Follow these instructions:\n"
        + "".join(rules)
    )
    if lang != "english":
        return _translate_instruction(basic_instruction, target_language=lang)
    return basic_instruction
def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
):
    """Machine-translate a QA example's question, context and first answer.

    Args:
        example: Mapping with ``"question"``, ``"context"`` and ``"answers"``
            (a sequence of answer strings; only the first is translated).
        src_language: Source language name (case-insensitive LANGUAGE_TO_SUFFIX key).
        target_language: Target language name (case-insensitive LANGUAGE_TO_SUFFIX key).

    Returns:
        ``{"question": str, "context": str, "answers": str}``. ``"answers"``
        is a single translated string rather than a list, or ``""`` when the
        example has no answers.

    Only the first 6000 characters of the context are translated, in
    2000-character requests. Empty chunks (short contexts) are not sent to
    the translator.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[str(src_language).lower()],
        target_language=LANGUAGE_TO_SUFFIX[str(target_language).lower()],
        timeout=30,
    )
    chunk_size, max_chars = 2000, 6000

    translated_question = translator.translate(example["question"])
    context = example["context"]
    translated_context = "".join(
        translator.translate(context[start:start + chunk_size])
        for start in range(0, min(len(context), max_chars), chunk_size)
    )
    answers = example["answers"]
    translated_answer = translator.translate(answers[0]) if answers else ""
    return {
        "question": translated_question,
        "context": translated_context,
        "answers": translated_answer,
    }
# except Exception as e:
# print(example["text"])
# print(example["summary"])
# print(e)
def choose_few_shot_examples(
    train_dataset: "Dataset",
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Select in-context examples and translate each into its requested language.

    Args:
        train_dataset: Indexable pool of QA rows with ``question``,
            ``context`` and ``answers["text"]`` fields.
        few_shot_size: Number of examples to draw.
        context: Target language for each example, by position. It must not
            be longer than ``few_shot_size``.
        selection_criteria: ``"first_k"`` takes the first rows; ``"random"``
            draws uniformly with replacement, so duplicates are possible.
        lang: Language of ``train_dataset``. Examples whose target language
            equals it are used untranslated.

    Returns:
        One dict per entry of ``context``. Untranslated examples keep
        ``answers`` as a list; translated ones carry the single string
        produced by ``_translate_example``.

    Raises:
        ValueError: For an unknown ``selection_criteria``.
        IndexError: If ``context`` is longer than ``few_shot_size``.
    """
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )
    else:
        raise ValueError(
            f"Unknown selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    # Fetch each row once; indexing a dataset decodes the whole row.
    rows = (train_dataset[idx] for idx in example_idxs)
    ic_examples = [
        {
            "question": row["question"],
            "context": row["context"],
            "answers": row["answers"]["text"],
        }
        for row in rows
    ]

    selected_examples = []
    for idx, ic_language in enumerate(context):
        example = ic_examples[idx]
        if ic_language == lang:
            selected_examples.append(example)
        else:
            selected_examples.append(
                _translate_example(
                    example=example,
                    src_language=lang,
                    target_language=ic_language,
                )
            )
    return selected_examples
def normalize_answer(s):
    """Normalize answer text for comparison.

    Steps, in order:

    1. Lower-case the text.
    2. Drop every character in the module-level ``PUNCT`` set (Unicode
       punctuation plus ASCII ``string.punctuation``).
    3. Replace the English articles "a", "an" and "the" with spaces.
    4. Collapse runs of whitespace into single spaces.
    """
    text = s.lower()
    # PUNCT is already a set: test membership directly instead of copying
    # it on every call.
    text = "".join(ch for ch in text if ch not in PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
def process_test_example(
    test_data, config_header, idx, test_example, config, zero_shot, lang, params
):
    """Build the prompt for one test example, query the model and log the prediction.

    Few-shot examples (when ``zero_shot`` is false) are drawn at random from
    ``test_data`` itself. The prediction is printed and appended as a JSON
    line to ``{response_logger_root}/{model}/{lang}/{config_header}.csv``.
    Any exception is caught and printed, so one failing example does not stop
    a batch run.
    """
    try:
        instruction = create_instruction(
            lang=config["prefix"], expected_output=config["output"]
        )
        example_fields = dict(
            question=test_example["question"],
            context=test_example["context"],
            answers=test_example["answers"]["text"],
        )
        few_shot = (
            []
            if zero_shot
            else choose_few_shot_examples(
                train_dataset=test_data,
                few_shot_size=len(config["context"]),
                context=config["context"],
                selection_criteria="random",
                lang=params["selected_language"],
            )
        )
        prompt, label = construct_prompt(
            instruction=instruction,
            test_example=example_fields,
            ic_examples=few_shot,
            zero_shot=zero_shot,
            lang=lang,
            config=config,
        )
        pred = gpt3x_completion(prompt=prompt)
        print(pred)

        logger.info("Saving prediction to persistent volume")
        output_dir = f"{params['response_logger_root']}/{params['model']}/{lang}"
        os.makedirs(output_dir, exist_ok=True)
        dump_predictions(
            idx=idx,
            response=pred,
            label=label,
            response_logger_file=f"{output_dir}/{config_header}.csv",
        )
    except Exception as err:
        print(f"Error processing example {idx}: {err}")
def run_one_configuration(selected_language, config, zero_shot, dataset_name, limit=10):
    """Run the QA pipeline on up to ``limit`` examples and return the first prediction.

    The ``validation`` split is loaded for xquad and tydiqa, and ``test``
    otherwise. Each example then gets an instruction, optional few-shot
    examples (drawn from the same data) and a prompt, and the prompt is sent
    to ``gpt3x_completion``.

    Returns:
        The prediction for the first example processed without an exception,
        or ``None`` if every example fails. Each failure is printed and the
        next example is tried.
    """
    # xquad and tydiqa ("secondary_task") only publish train/validation;
    # indexing the tydiqa dataset with "test" raised KeyError.
    split = "validation" if dataset_name in ("xquad", "tydiqa") else "test"
    test_data = load_qa_dataset(
        dataset_name=dataset_name,
        lang=selected_language,
        split=split,
        limit=limit,
    )
    for test_example in tqdm(test_data):
        try:
            instruction = create_instruction(
                lang=config["prefix"], expected_output=config["output"]
            )
            text_example = {
                "question": test_example["question"],
                "context": test_example["context"],
                "answers": test_example["answers"]["text"],
            }
            ic_examples = []
            if not zero_shot:
                ic_examples = choose_few_shot_examples(
                    train_dataset=test_data,
                    few_shot_size=len(config["context"]),
                    context=config["context"],
                    selection_criteria="random",
                    lang=selected_language,
                )
            prompt, _label = construct_prompt(
                instruction=instruction,
                test_example=text_example,
                ic_examples=ic_examples,
                zero_shot=zero_shot,
                lang=selected_language,
                config=config,
            )
            # NOTE(review): returning inside the loop means only the first
            # successfully processed example is evaluated; later rows act
            # only as fallbacks. Confirm this is intended.
            return gpt3x_completion(prompt=prompt)
        except Exception as e:
            print(f"Found an exception {e}, continue to the next example")
    return None
# Task identifiers compared against by construct_generic_prompt and recommend_config.
QA, SUMMARIZATION, NLI, NER = "QA", "Summarization", "NLI", "NER"
def construct_generic_prompt(
    task, instruction, test_example, zero_shot, num_examples, selected_language, dataset, config
):
    """Dispatch to the task-specific ``construct_prompt`` and return its result.

    ``task`` is printed first. SUMMARIZATION, NER and QA go to their own
    modules; any other value falls through to the NLI builder. Only the
    summarization builder receives ``dataset``. ``selected_language`` is
    stringified and lower-cased before being passed on as ``lang``.
    """
    print(task)
    shared_kwargs = dict(
        instruction=instruction,
        test_example=test_example,
        zero_shot=zero_shot,
        num_examples=num_examples,
        lang=str(selected_language).lower(),
        config=config,
    )
    if task == SUMMARIZATION:
        return summarization.construct_prompt(dataset=dataset, **shared_kwargs)
    if task == NER:
        return ner.construct_prompt(**shared_kwargs)
    if task == QA:
        return qa.construct_prompt(**shared_kwargs)
    return nli.construct_prompt(**shared_kwargs)
def _get_language_type(language: str):
    """Classify ``language`` as ``LanguageType.Low`` or ``LanguageType.High`` by word count.

    On every call this reads ``utils/languages_by_word_count.csv`` (relative
    to the working directory) and takes the first ``'number of words'`` value
    whose ``'Language'`` equals ``language``. The value is printed. Counts
    below 150,276,400 give ``Low``; anything else gives ``High``. An unlisted
    language raises ``IndexError``.
    """
    low_threshold = 150_276_400
    word_counts = pd.read_csv("utils/languages_by_word_count.csv")
    matching = word_counts.loc[word_counts["Language"] == language, "number of words"]
    number_of_words = matching.iloc[0]
    print(number_of_words)
    if number_of_words < low_threshold:
        return LanguageType.Low
    return LanguageType.High
class Config:
    """Which language each prompt component uses.

    Each of ``prefix``, ``context``, ``examples`` and ``output`` holds a
    language marker such as ``"source"`` (the default) or ``"english"``.
    """

    def __init__(self, prefix="source", context="source", examples="source", output="source"):
        self.prefix, self.context, self.examples, self.output = prefix, context, examples, output

    def set(self, prefix=None, context=None, examples=None, output=None):
        """Overwrite the given fields. Falsy arguments (None, "") leave a field unchanged."""
        updates = {"prefix": prefix, "context": context, "examples": examples, "output": output}
        for field_name, value in updates.items():
            if value:
                setattr(self, field_name, value)

    def to_dict(self):
        """Return the four fields as a dict in prefix/context/examples/output order."""
        return {name: getattr(self, name) for name in ("prefix", "context", "examples", "output")}
def recommend_config(task, lang, model_type):
    """Recommend which prompt components to keep in the source language vs. English.

    Args:
        task: Task identifier (QA, NER, NLI or SUMMARIZATION).
        lang: Language name looked up by ``_get_language_type`` in
            ``utils/languages_by_word_count.csv``.
        model_type: ``ModelType.English`` or its string value ``"English"``
            for an English model; anything else is treated as multilingual.
            Comparing only against the string previously sent an enum
            member down the multilingual branch.

    Returns:
        Dict with keys ``prefix``, ``context``, ``examples``, ``output``;
        each value is ``"source"`` or ``"english"``.
    """
    print(task)
    print(model_type)
    language_type = _get_language_type(lang)
    config = Config()
    print(language_type)

    is_english_model = model_type in (ModelType.English, ModelType.English.value)

    if task == QA:
        if is_english_model:
            config.set(prefix='source', context='source', examples='source', output='source')
        else:
            config.set(prefix='english', context='source', examples='source', output='source')
    elif task == NER:
        if is_english_model:
            config.set(prefix='source', context='source', examples='source', output='source')
        elif language_type == LanguageType.High:
            config.set(prefix='english', context='source', examples='source', output='source')
        else:
            config.set(prefix='english', context='source', examples='source', output='english')
    elif task == NLI:
        if is_english_model:
            config.set(prefix='source', context='source', examples='source', output='source')
        elif language_type == LanguageType.High:
            config.set(prefix='english', context='source', examples='english')
        else:
            config.set(prefix='english', context='english', examples='english')
    elif task == SUMMARIZATION:
        config.set(context='english')
    return config.to_dict()