# machine-translation/llm_toolkit/translation_utils.py
# dh-mc's picture
# enable do_sample
# a69b127
# raw
# history blame
# 11.3 kB
import os
import re
import pandas as pd
import evaluate
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from tqdm import tqdm
from eval_modules.calc_repetitions import *
from llm_toolkit.llm_utils import load_tokenizer
print(f"loading {__file__}")

# Hugging Face `evaluate` metrics, loaded once at import time, in this order.
# bleu / rouge / meteor are read by calc_metrics.
# NOTE(review): the module-level `accuracy` metric appears unused in this file.
bleu, rouge, meteor, accuracy = (
    evaluate.load(metric_name)
    for metric_name in ("bleu", "rouge", "meteor", "accuracy")
)
def extract_answer(text, debug=False):
    """Strip chat-template scaffolding from a raw generation.

    Three regex passes run in order (DOTALL | MULTILINE), each deleting every
    match:

    1. text up to and just past each ``assistant`` / ``[/INST]`` marker
       (up to the next word boundary). NOTE(review): this also fires on the
       literal word "assistant" inside a translation.
    2. the first ``<...>`` tag and everything after it.
    3. text up to and including ``end_header_id|>`` followed by a blank line.

    Args:
        text: Raw model output. Falsy values (``None``, ``""``) are returned
            unchanged.
        debug: When true, print the intermediate text after each pass.

    Returns:
        The cleaned text, or the original falsy value.
    """
    if not text:
        return text

    cleanup_passes = (
        ("--------\nstep 1:", r".*?(assistant|\[/INST\]).+?\b"),
        ("--------\nstep 2:", r"<.+?>.*"),
        ("--------\nstep 3:", r".*?end_header_id\|>\n\n"),
    )
    for label, pattern in cleanup_passes:
        text = re.sub(pattern, "", text, flags=re.DOTALL | re.MULTILINE)
        if debug:
            print(label, text)
    return text
def calc_metrics(references, predictions, debug=False):
    """Score predictions against references with METEOR, BLEU, ROUGE and exact match.

    Each prediction is first cleaned with ``extract_answer``.

    Args:
        references: Reference translations, one per prediction.
        predictions: Raw model outputs, aligned with ``references``.
        debug: When true, also return the indices of exact matches under
            ``"correct_ids"``.

    Returns:
        A dict with ``"meteor"`` (the METEOR score), ``"bleu_scores"`` and
        ``"rouge_scores"`` (the full result dicts from the ``evaluate``
        metrics), ``"accuracy"`` (fraction of predictions exactly equal to
        their reference) and, when ``debug`` is set, ``"correct_ids"``.

    Raises:
        AssertionError: If ``references`` and ``predictions`` differ in length.
    """
    # Explicit check rather than `assert`, so it still runs under `python -O`
    # (otherwise zip() below would silently truncate). AssertionError is kept
    # so callers see the same exception type as before.
    if len(references) != len(predictions):
        raise AssertionError(
            f"lengths are different: {len(references)} != {len(predictions)}"
        )

    predictions = [extract_answer(text) for text in predictions]

    results = {
        "meteor": meteor.compute(predictions=predictions, references=references)[
            "meteor"
        ],
        "bleu_scores": bleu.compute(
            predictions=predictions, references=references, max_order=4
        ),
        "rouge_scores": rouge.compute(
            predictions=predictions, references=references
        ),
    }

    # Exact-match flags. The ratio is stored directly rather than in a local
    # named `accuracy`, which would shadow the module-level metric.
    correct = [1 if ref == pred else 0 for ref, pred in zip(references, predictions)]
    results["accuracy"] = sum(correct) / len(references)

    if debug:
        results["correct_ids"] = [i for i, c in enumerate(correct) if c == 1]
    return results
def save_results(model_name, results_path, dataset, predictions, debug=False):
    """Store one model's predictions as a column of a CSV results file.

    If ``results_path`` does not exist yet, the table is seeded from
    ``dataset.to_pandas()`` without the prompt-formatting ``text`` / ``prompt``
    columns. Otherwise the existing CSV is read back. The predictions are then
    written under ``model_name``, replacing any column with that name, and the
    whole table is saved.

    Args:
        model_name: Column name to store the predictions under.
        results_path: CSV file to create or update. May be a bare filename.
        dataset: Object with a ``to_pandas()`` method. It is used only when
            the file does not exist yet.
        predictions: One prediction per row of the results table.
        debug: When true, print the first row of the updated table.
    """
    if os.path.exists(results_path):
        df = pd.read_csv(results_path, on_bad_lines="warn")
    else:
        dir_path = os.path.dirname(results_path)
        # A bare filename has no directory part, and os.makedirs("") raises.
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        df = dataset.to_pandas()
        # The prompt columns exist only when the dataset was formatted with a
        # tokenizer (see load_translation_dataset), so drop them if present.
        df = df.drop(columns=["text", "prompt"], errors="ignore")

    df[model_name] = predictions
    if debug:
        print(df.head(1))
    df.to_csv(results_path, index=False)
def load_translation_dataset(data_path, tokenizer=None):
    """Load the Chinese-to-English dataset as ``train`` / ``test`` splits.

    On first use, the TSV at ``data_path`` is read and rows with an empty
    ``chinese`` or ``english`` field are dropped. The rest is split 80/20
    (unseeded) and cached as ``*-train.tsv`` / ``*-test.tsv``, with the names
    derived by replacing ``.tsv`` in ``data_path``. Later calls reuse the
    cached files.

    Args:
        data_path: Tab-separated file with ``chinese`` and ``english`` columns.
        tokenizer: Optional chat tokenizer. When given, two columns are added
            via ``tokenizer.apply_chat_template``:
            - ``prompt``: the generation prompt.
            - ``text``: prompt + reference + ``eos_token``.
            The system message is omitted when the ``MODEL_NAME`` environment
            variable contains "mistral" (case-insensitive).

    Returns:
        The object returned by ``datasets.load_dataset`` with ``train`` and
        ``test`` splits (plus the prompt columns when ``tokenizer`` is given).
    """
    train_data_file = data_path.replace(".tsv", "-train.tsv")
    test_data_file = data_path.replace(".tsv", "-test.tsv")

    if not os.path.exists(train_data_file):
        print("generating train/test data files")
        dataset = load_dataset(
            "csv", data_files=data_path, delimiter="\t", split="train"
        )
        print(len(dataset))
        # Keep only rows where both sides of the translation pair are present.
        dataset = dataset.filter(lambda x: x["chinese"] and x["english"])
        datasets = dataset.train_test_split(test_size=0.2)
        print(len(dataset))
        pd.DataFrame(datasets["train"]).to_csv(train_data_file, sep="\t", index=False)
        pd.DataFrame(datasets["test"]).to_csv(test_data_file, sep="\t", index=False)

    print("loading train/test data files")
    datasets = load_dataset(
        "csv",
        data_files={"train": train_data_file, "test": test_data_file},
        delimiter="\t",
    )

    if tokenizer:
        translation_prompt = "Please translate the following Chinese text into English and provide only the translated content, nothing else.\n{}"
        system_message = {
            "role": "system",
            "content": "You are an expert in translating Chinese to English.",
        }
        # MODEL_NAME may be unset; treat that as "not mistral" instead of
        # crashing on None.lower(). Read once rather than once per batch.
        include_system = "mistral" not in os.getenv("MODEL_NAME", "").lower()

        def formatting_prompts_func(examples):
            texts = []
            prompts = []
            for source, target in zip(examples["chinese"], examples["english"]):
                messages = [system_message] if include_system else []
                messages.append(
                    {"role": "user", "content": translation_prompt.format(source)}
                )
                prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                prompts.append(prompt)
                texts.append(prompt + target + tokenizer.eos_token)
            return {"text": texts, "prompt": prompts}

        datasets = datasets.map(
            formatting_prompts_func,
            batched=True,
        )

    print(datasets)
    return datasets
def get_metrics(df, max_output_tokens=2048):
    """Build a per-model summary of translation and repetition metrics.

    Assumes the first two columns of ``df`` are data columns and every later
    column holds one model's predictions. Prediction columns are named
    ``<model>`` or ``<model>/rpp-<value>`` (NOTE(review): confirm against the
    results files). ``df["english"]`` is used as the reference.

    ``df`` itself is not modified; helper columns go into a private copy.

    Args:
        df: Results table, e.g. as written by ``save_results``.
        max_output_tokens: Token count treated as "hit the output limit".

    Returns:
        A DataFrame with one row per prediction column and the columns
        ``model``, ``rpp``, ``meteor``, ``bleu_1``, ``rouge_l``,
        ``ews_score``, ``repetition_score``, ``total_repetitions`` and
        ``num_entries_with_max_output_tokens``.
    """
    model_columns = list(df.columns[2:])
    # Scratch copy: writing helper columns into the caller's frame would make
    # them look like extra model columns on a later call.
    work = df.copy()

    metrics_df = pd.DataFrame({"model": model_columns})
    metrics_df["rpp"] = metrics_df["model"].apply(lambda x: x.split("rpp-")[-1])
    metrics_df["model"] = metrics_df["model"].apply(lambda x: x.split("/rpp-")[0])

    tokenizers = {
        model: load_tokenizer(model) for model in metrics_df["model"].unique()
    }

    meteor_scores = []
    bleu_1 = []
    rouge_l = []
    ews_score = []
    repetition_score = []
    total_repetitions = []
    num_entries_with_max_output_tokens = []

    for col, model in zip(model_columns, metrics_df["model"]):
        metrics = calc_metrics(work["english"], work[col], debug=True)
        print(f"{col}: {metrics}")
        meteor_scores.append(metrics["meteor"])
        # NOTE(review): despite the "bleu_1" name this is the overall `bleu`
        # value from `evaluate` (calc_metrics requests max_order=4).
        bleu_1.append(metrics["bleu_scores"]["bleu"])
        rouge_l.append(metrics["rouge_scores"]["rougeL"])

        work[["ews_score", "repetition_score", "total_repetitions"]] = work[
            col
        ].apply(detect_scores)
        ews_score.append(work["ews_score"].mean())
        repetition_score.append(work["repetition_score"].mean())
        total_repetitions.append(work["total_repetitions"].mean())

        # Same model key that was used to build `tokenizers` above.
        tokenizer = tokenizers[model]
        output_tokens = work[col].apply(lambda x: len(tokenizer(x)["input_ids"]))
        # Rows whose tokenized length equals the limit exactly.
        # NOTE(review): input_ids may include special tokens depending on the
        # tokenizer — confirm this lines up with the generation limit.
        num_entries_with_max_output_tokens.append(
            output_tokens.value_counts().get(max_output_tokens, 0)
        )

    metrics_df["meteor"] = meteor_scores
    metrics_df["bleu_1"] = bleu_1
    metrics_df["rouge_l"] = rouge_l
    metrics_df["ews_score"] = ews_score
    metrics_df["repetition_score"] = repetition_score
    metrics_df["total_repetitions"] = total_repetitions
    metrics_df["num_entries_with_max_output_tokens"] = (
        num_entries_with_max_output_tokens
    )
    return metrics_df
def plot_metrics(metrics_df, figsize=(14, 5), ylim=(0, 0.44)):
    """Draw a grouped, hatched bar chart of METEOR / BLEU-1 / ROUGE-L per model.

    Args:
        metrics_df: Frame with a ``model`` column and ``meteor``, ``bleu_1``
            and ``rouge_l`` score columns.
        figsize: Matplotlib figure size.
        ylim: y-axis limits for the scores.
    """
    # Distinct models in order of first appearance. They are passed as the
    # explicit hue order so the hatch mapping below is guaranteed to match.
    models = list(metrics_df["model"].unique())

    plt.figure(figsize=figsize)
    df_melted = pd.melt(
        metrics_df, id_vars="model", value_vars=["meteor", "bleu_1", "rouge_l"]
    )
    barplot = sns.barplot(
        x="variable", y="value", hue="model", hue_order=models, data=df_melted
    )

    # One hatch per model, cycling when there are more models than hatches.
    hatches = ["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*", "//", "\\\\"]
    model_hatches = {
        model: hatches[i % len(hatches)] for i, model in enumerate(models)
    }

    # Assumes seaborn draws bars one hue level at a time, `num_vars` bars per
    # model. Indexing the distinct models (instead of rows of `df_melted`)
    # stays correct when a model has several rows, e.g. one per rpp value.
    num_vars = df_melted["variable"].nunique()
    for i, bar in enumerate(barplot.patches):
        model_idx = i // num_vars
        if model_idx >= len(models):
            break
        bar.set_hatch(model_hatches[models[model_idx]])

    # Manually update the legend to match the bar hatches.
    handles, labels = barplot.get_legend_handles_labels()
    for handle, model in zip(handles, models):
        handle.set_hatch(model_hatches[model])

    barplot.set_xticklabels(["METEOR", "BLEU-1", "ROUGE-L"])

    # Label each bar with its score, skipping empty or missing (NaN) bars.
    for p in barplot.patches:
        height = p.get_height()
        if height == 0 or pd.isna(height):
            continue
        barplot.annotate(
            f"{height:.2f}",
            (p.get_x() + p.get_width() / 2.0, height),
            ha="center",
            va="center",
            xytext=(0, 10),
            textcoords="offset points",
        )

    barplot.set(ylim=ylim, ylabel="Scores", xlabel="Metrics")
    plt.legend(bbox_to_anchor=(0.5, -0.1), loc="upper center")
    plt.show()
def plot_times(perf_df, ylim=0.421):
    """Plot stacked train/eval time bars per model, with METEOR on a twin axis.

    Args:
        perf_df: Frame with ``model``, ``train-time(mins)`` and
            ``eval-time(mins)`` columns, and optionally ``meteor``.
        ylim: Upper limit of the METEOR axis; its lower limit is left as
            matplotlib chose it.
    """
    models = perf_df["model"]
    fig, ax1 = plt.subplots(figsize=(12, 10))

    ax1.set_xlabel("Models")
    ax1.set_ylabel("Time (mins)")
    ax1.set_xticks(range(len(models)))
    ax1.set_xticklabels(models, rotation=90)

    # Stacked bars: train time at the bottom (red), eval time on top (orange).
    stacked_bars = (
        ("train-time(mins)", "tab:red", "train-time"),
        ("eval-time(mins)", "orange", "eval-time"),
    )
    base = None
    for column, color, label in stacked_bars:
        ax1.bar(models, perf_df[column], bottom=base, color=color, label=label)
        base = perf_df[column]

    ax1.tick_params(axis="y")
    ax1.legend(loc="upper left")

    if "meteor" in perf_df.columns:
        meteor_color = "tab:blue"
        ax2 = ax1.twinx()
        ax2.set_ylabel("METEOR", color=meteor_color)
        ax2.plot(
            models,
            perf_df["meteor"],
            color=meteor_color,
            marker="o",
            label="meteor",
        )
        ax2.tick_params(axis="y", labelcolor=meteor_color)
        ax2.legend(loc="upper right")
        lower, _ = ax2.get_ylim()
        ax2.set_ylim(lower, ylim)

    # Label every non-empty bar segment with its value, 10 points below its top.
    for patch in ax1.patches:
        height = patch.get_height()
        if height == 0:
            continue
        top_center = (patch.get_x() + patch.get_width() / 2.0, patch.get_y() + height)
        ax1.annotate(
            f"{height:.2f}",
            top_center,
            ha="center",
            va="center",
            xytext=(0, -10),
            textcoords="offset points",
        )

    fig.tight_layout()
    plt.show()
def translate_via_llm(text, *, model="gpt-4o", base_url=None):
    """Translate Chinese ``text`` into English with an OpenAI-compatible chat model.

    A new ``ChatOpenAI`` client and prompt are built on every call. Sampling
    uses temperature 0, with up to 2 retries.

    Args:
        text: Chinese source text.
        model: Model name sent to the endpoint. Keyword-only; defaults to
            ``"gpt-4o"``, the value previously hard-coded.
        base_url: API base URL. When not given (or empty), ``OPENAI_BASE_URL``
            is used, falling back to ``http://localhost:8000/v1``.

    Returns:
        The ``content`` of the model's reply.
    """
    base_url = base_url or os.getenv("OPENAI_BASE_URL") or "http://localhost:8000/v1"
    llm = ChatOpenAI(
        model=model,
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        base_url=base_url,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "human",
                "Please translate the following Chinese text into English and provide only the translated content, nothing else.\n{input}",
            ),
        ]
    )
    chain = prompt | llm
    response = chain.invoke(
        {
            "input": text,
        }
    )
    return response.content
def translate(text, cache_dict):
    """Return the translation of ``text``, memoized in ``cache_dict``.

    On a cache miss, ``translate_via_llm`` is called and its result is stored
    in ``cache_dict`` (mutated in place) before being returned.
    """
    if text not in cache_dict:
        fresh = translate_via_llm(text)
        cache_dict[text] = fresh
        return fresh
    return cache_dict[text]