Spaces:
Build error
Build error
import json | |
import os | |
import re | |
import string | |
import zipfile | |
import httpx | |
def download_data(dir): | |
import gdown | |
data_path = os.path.join(dir, 'data/external_corpus') | |
if os.path.exists(data_path): | |
return data_path | |
url = 'https://drive.google.com/uc?id=1zRbHzPW2x4dDcfmphBWlan8cxUCRNmqk' | |
zip_path = os.path.join(dir, 'data.zip') | |
gdown.download(url, zip_path, quiet=False) | |
with zipfile.ZipFile(zip_path, 'r') as zip_ref: | |
zip_ref.extractall(os.path.join(dir, 'data')) | |
if os.path.exists(zip_path): | |
os.remove(zip_path) | |
print(f'Data saved to {data_path}') | |
return data_path | |
def download_tools(dir, wolfram_alpha_appid='YOUR_WOLFRAMALPHA_APPID'): | |
tool_path = os.path.join(dir, 'tools') | |
if os.path.exists(tool_path): | |
return tool_path | |
os.mkdir(tool_path) | |
tools = [ | |
'code/sql_interpreter.py', | |
'graph/graphtools.py', | |
'math/calculator.py', | |
'table/mysql_db_create.py', | |
'table/tabtools.py', | |
'text/agenda_retriever.py', | |
'text/scirex_retriever.py', | |
] | |
for tool in tools: | |
url = f'https://raw.githubusercontent.com/night-chen/ToolQA/main/benchmark/ReAct/code/tools/{tool}' | |
response = httpx.get(url) | |
output_file = os.path.join(tool_path, tool.split('/')[1]) | |
with open(output_file, 'wb') as f: | |
f.write(response.content) | |
print(f'Tool saved to {output_file}') | |
with open(os.path.join(tool_path, 'calculator.py'), 'r') as f: | |
content = f.read() | |
new_content = content.replace('YOUR_WOLFRAMALPHA_APPID', wolfram_alpha_appid) | |
with open(os.path.join(tool_path, 'calculator.py'), 'w') as f: | |
f.write(new_content) | |
with open(os.path.join(tool_path, 'agenda_retriever.py'), 'r') as f: | |
content = f.read() | |
new_content = content.replace('/<YOUR_OWN_PATH>/ToolQA/', '') | |
with open(os.path.join(tool_path, 'agenda_retriever.py'), 'w') as f: | |
f.write(new_content) | |
with open(os.path.join(tool_path, 'mysql_db_create.py'), 'r') as f: | |
content = f.read() | |
new_content = content.replace('/<YOUR_OWN_PATH>/ToolQA/', '') | |
with open(os.path.join(tool_path, 'mysql_db_create.py'), 'w') as f: | |
f.write(new_content) | |
with open(os.path.join(tool_path, 'scirex_retriever.py'), 'r') as f: | |
content = f.read() | |
new_content = content.replace('/<YOUR_OWN_PATH>/ToolQA/', '') | |
with open(os.path.join(tool_path, 'scirex_retriever.py'), 'w') as f: | |
f.write(new_content) | |
LOCAL_DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') | |
def get_data(dataset, hardness): | |
data_path = os.path.join(LOCAL_DATA_DIR, f'{dataset}-{hardness}.jsonl') | |
if os.path.exists(data_path): | |
print(f'Loading data from {data_path}') | |
with open(data_path, 'r') as f: | |
return json.load(f) | |
else: | |
print( | |
f'Downloading data from https://raw.githubusercontent.com/night-chen/ToolQA/main/data/questions/{hardness}/{dataset}-{hardness}.jsonl' | |
) | |
data = [] | |
url = f'https://raw.githubusercontent.com/night-chen/ToolQA/main/data/questions/{hardness}/{dataset}-{hardness}.jsonl' | |
url = httpx.get(url) | |
if url.status_code == 200: | |
lines = url.text.splitlines() | |
for line in lines: | |
data.append(json.loads(line)) | |
with open(data_path, 'w') as f: | |
json.dump(data, f) | |
print(f'Data saved to {data_path}') | |
return data | |
REACT_INSTRUCTION = """Use tools in the tools directory to solve the task: {question} | |
You could use all tools which are under the tools/ directory and all the data under the data/ directory. | |
When you think you finished the task, respond with `Finish[answer]` where you include your answer in `[]`. | |
IMPORTANT: Make sure that in your final answer, you should not print any additional text/instructions other than the actual answer, which should be a word or a simple phrase. | |
""" | |
def encode_question(question): | |
return REACT_INSTRUCTION.format(question=question) | |
# imported from https://github.com/night-chen/ToolQA/tree/main/benchmark/ReAct/code/agents_chatgpt.py | |
def normalize_answer(s): | |
def remove_articles(text): | |
return re.sub(r'\b(a|an|the|usd)\b', ' ', text) | |
def white_space_fix(text): | |
return ' '.join(text.split()) | |
def remove_punc(text): | |
exclude = set(string.punctuation) | |
return ''.join(ch for ch in text if ch not in exclude) | |
def lower(text): | |
return text.lower() | |
return white_space_fix(remove_articles(remove_punc(lower(s)))) | |
def eval_answer(pred, answer): | |
pattern = r'Finish\[(.*?)\]' | |
match = re.search(pattern, pred) | |
if match: | |
pred = match.group(1) | |
return normalize_answer(pred) == normalize_answer(answer) | |