Spaces:
Runtime error
Runtime error
"""Utility functions for setting a random seed, loading data from a JSON file,
batching data, and extracting answers from generated text.
"""
import random | |
import numpy as np | |
import torch | |
import json | |
import re | |
def set_random_seed(seed: int):
    """
    Seed every random number generator used by this module.

    The same value is passed to `random`, `numpy`, and `torch`; when CUDA
    is available, all CUDA devices are seeded as well.

    Parameters
    ------------
    seed : int
        The seed value to apply.
    """
    # Same order as before: Python stdlib, then NumPy, then the torch CPU RNG.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators are only touched when a CUDA runtime is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def load_data(file_name: str):
    """
    Read a JSON dataset file and split it into inputs and outputs.

    The file must hold a JSON object with a "type" entry and an
    "instances" entry; every instance must provide "input" and "output".
    A short summary is printed once the file has been read.

    Parameters
    ------------
    file_name : str.
        Path of the dataset file (read as UTF-8).

    Returns
    ------------
    inputs : list.
        The "input" value of every instance, in file order.

    outputs : list.
        The "output" value of every instance, in file order.

    len : int.
        The number of instances.
    """
    with open(file_name, encoding='utf-8') as handle:
        payload = json.load(handle)

    dataset_type = payload["type"]
    pairs = [(record["input"], record["output"]) for record in payload["instances"]]
    inputs = [source for source, _ in pairs]
    outputs = [target for _, target in pairs]
    size = len(outputs)

    print(f"load dataset {file_name} success.\n")
    print(f"Type : {dataset_type}, datasize : {size}")

    return inputs, outputs, size
def batchlize(examples: list, batch_size: int, random_shuffle: bool):
    """
    Split examples into consecutive batches.

    Parameters
    ------------
    examples : list.
        Data list. NOTE: when `random_shuffle` is true the list is shuffled
        in place, so the caller's list is reordered too.

    batch_size : int.
        Maximum number of examples per batch; must be positive.

    random_shuffle : bool
        If true, shuffle the examples before batching.

    Returns
    ------------
    dataloader:
        List of batches (slices of `examples`). Every batch has
        `batch_size` items except possibly the last one, which holds the
        remainder. An empty input yields an empty list.

    Raises
    ------------
    ValueError
        If `batch_size` is not positive (the previous loop never terminated
        in that case).
    """
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}"
        )

    if random_shuffle:
        random.shuffle(examples)

    # Stepping the start index by batch_size yields full batches followed by
    # one shorter tail batch when the length is not a multiple of batch_size.
    return [
        examples[start:start + batch_size]
        for start in range(0, len(examples), batch_size)
    ]
def _extract_binary_choice(text):
    """Return "yes", "no" or "maybe" found in `text`, or "N/A" if none is found."""
    # Prefer an explicitly labelled answer such as "Answer: (yes".
    labelled = re.search(
        r"(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*"
        r"(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)",
        text,
    )
    if labelled is not None:
        # Use the captured word so optional opening parentheses are dropped.
        return labelled.group(2).lower()

    # Otherwise take the first yes/no/maybe followed by "." or whitespace.
    bare = re.search(r"(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)(\.|\s)", text)
    if bare is not None:
        return bare.group(1).lower()
    return "N/A"


def _extract_option_letter(text, labels):
    """Return the lower-cased option letter a-d found in `text`, or "N/A".

    `labels` is a regex alternation of the prefixes accepted in front of
    ": " (for example "Answer|Output|A").
    """
    labelled = re.search(rf"({labels}): \(*(A|B|C|D|a|b|c|d)", text)
    if labelled is not None:
        return labelled.group(2).lower()

    # Fallback: first option letter, optionally wrapped in parentheses,
    # that is followed by "." or whitespace.
    bare = re.search(r"\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)", text)
    if bare is not None:
        # The capture group is the letter itself, however many "(" precede it.
        return bare.group(1).lower()
    return "N/A"


def answer_extraction(response, answer_type=None, concat_length=None):
    """
    Use this function to extract answers from generated text.

    Parameters
    ------------
    response : str
        Plain string response.

    answer_type : str
        Name of the task, which selects the extraction rule:

        * "gsm8k", "svamp", "asdiv", "addsub", "singleeq", "multiarith",
          "math": last number in the text (thousands commas removed);
          for "gsm8k" and "svamp" it is rounded to the nearest integer.
        * "aqua", "csqa", "multiple_choice": last of the characters A-E.
        * "strategyqa", "coin_flip": last "yes" / "no" word.
        * "last_letters": the text with quotes, periods and whitespace
          removed.
        * "pubmedqa", "binary_choice": "yes" / "no" / "maybe", else "N/A".
        * "medmcqa", "usmle": option letter "a"-"d", else "N/A".
        * "text": the response is returned unchanged.

    concat_length : int, optional
        Only used for "last_letters". When positive, keep only the last
        `concat_length` characters of the cleaned text; otherwise the whole
        cleaned text is returned.

    Returns
    ------------
    answer:
        Decoded answer (such as A, B, C, D, E for multiple-choice QA).
        An empty string is returned when nothing could be extracted for
        the list-based task types.

    Raises
    ------------
    NotImplementedError
        If `answer_type` is not one of the supported names (including the
        default `None`).
    """
    if answer_type == "text":
        return response
    if answer_type in ("pubmedqa", "binary_choice"):
        return _extract_binary_choice(response)
    if answer_type == "medmcqa":
        return _extract_option_letter(
            response, "answer|Answer|ANSWER|output|Output|OUTPUT|A"
        )
    if answer_type == "usmle":
        return _extract_option_letter(response, "Answer|Output|A")

    # The remaining task types collect candidates and keep the last one.
    if answer_type in (
        "gsm8k", "svamp", "asdiv", "addsub", "singleeq", "multiarith", "math"
    ):
        # Drop thousands separators so "1,234" is read as one number.
        candidates = re.findall(r"-?\d+\.?\d*", response.replace(",", ""))
    elif answer_type in ("aqua", "csqa", "multiple_choice"):
        candidates = re.findall(r"A|B|C|D|E", response)
    elif answer_type in ("strategyqa", "coin_flip"):
        # Turn punctuation and whitespace into spaces so words split cleanly.
        cleaned = re.sub(r"""["'.\s:,]""", " ", response.lower())
        candidates = [word for word in cleaned.split(" ") if word in ("yes", "no")]
    elif answer_type == "last_letters":
        candidates = [re.sub(r"""["'.\s]""", "", response)]
    else:
        raise NotImplementedError(f"Unsupported answer type: {answer_type}")

    if not candidates:
        return ""

    answer = candidates[-1]
    # A number at the end of a sentence is captured as e.g. "64."; drop the dot.
    if answer.endswith("."):
        answer = answer[:-1]

    if answer_type in ("gsm8k", "svamp"):
        # Round the answer to the nearest integer.
        try:
            answer = str(round(float(answer)))
        except (ValueError, OverflowError):
            answer = ""  # no solution or not a finite number
    elif answer_type == "last_letters":
        if concat_length is not None and concat_length > 0:
            answer = answer[-concat_length:]
    return answer