import os
import sys
import json
import glob
import time
import yaml
import joblib
import argparse
import jinja2
import anthropic
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from openai import OpenAI
from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

from data import get_leads
from utils import parse_json_garbage, compose_query

tqdm.pandas()

try:
    logger.remove(0)
    logger.add(sys.stderr, level="INFO")
except ValueError:
    pass

load_dotenv()
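
# load_dotenv() pulls credentials from a local .env file. A minimal sketch of the entries this
# script relies on (OPENAI_API_KEY is read implicitly by the OpenAI client; ORGANIZATION_ID is
# read explicitly in run_batch below; the values are placeholders):
#
#   OPENAI_API_KEY=sk-...
#   ORGANIZATION_ID=org-...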

def prepare_batch( crawled_result_path: str, config: dict, output_path: str, topn: int = None):
    """Prepare a batch input file (JSONL) for the extraction task.

    Argument
    --------
    crawled_result_path: str
        Path to the crawled result file (result from the crawl task)
    config: dict
        Configuration for the batch job (prompt templates, model, temperature, max_tokens, ...)
    output_path: str
        Path to the output JSONL file
    topn: int
        If given, only the first `topn` crawled records are processed

    Output
    ------
    Writes one request per line to `output_path`, e.g.
    {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
    {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
    """
    assert os.path.exists(crawled_result_path), f"File not found: {crawled_result_path}"
    crawled_results = joblib.load(open(crawled_result_path, "rb"))['crawled_results']
    if topn:
        crawled_results = crawled_results.head(topn)

    jenv = jinja2.Environment()
    template = jenv.from_string(config['extraction_prompt'])
    system_prompt = template.render( classes = config['classes'], traits = config['traits'])
    template = jenv.from_string(config['user_content'])

    items = []
    for i, d in tqdm(enumerate(crawled_results.itertuples())):
        idx = d.index  # d[1]
        evidence = d.googlemap_results + "\n" + d.search_results
        business_id = d.business_id  # d[2]
        business_name = d.business_name  # d[3]
        address = d.address  # d[7]
        ana_res = None
        query = compose_query( address, business_name, use_exclude=False)
        user_content = template.render( query = query, search_results = evidence)
        item = {
            "custom_id": str(business_id),
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": config['model'],
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content}
                ],
                "max_tokens": config['max_tokens'],
                "temperature": config['temperature'],
                "response_format": {"type": "json_object"},
            }
        }
        items.append( json.dumps(item, ensure_ascii=False))

    with open(output_path, "w") as f:
        for item in items:
            f.write(item + "\n")
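
# prepare_batch() above and prepare_regularization() below both read their prompts and model
# settings from the YAML config. A minimal sketch of the keys they reference (the values shown
# are placeholders, not the project's actual prompts or model choice):
#
#   extraction_prompt: "You are a classifier. Categories: {{ classes }}; traits: {{ traits }} ..."
#   user_content: "Query: {{ query }}\nSearch results: {{ search_results }}"
#   regularization_prompt: "Normalize the extracted category into one of the known classes ..."
#   regularization_user_content: "Category: {{ category }}"
#   classes: [...]
#   traits: [...]
#   model: gpt-3.5-turbo-0125
#   max_tokens: 4096
#   temperature: 0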

def prepare_regularization( extracted_result_path: str, config: dict, output_path: str, topn: int = None):
    """Prepare a batch input file (JSONL) for the category-regularization task.

    Argument
    --------
    extracted_result_path: str
        Path to the extracted result file (result from the extraction task)
    config: dict
        Configuration for the batch job (prompt templates, model, temperature, max_tokens, ...)
    output_path: str
        Path to the output JSONL file
    topn: int
        If given, only the first `topn` extracted records are processed

    Output
    ------
    Writes one request per line to `output_path`, in the same format as `prepare_batch`, e.g.
    {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
    """
    assert os.path.exists(extracted_result_path), f"File not found: {extracted_result_path}"
    extracted_results = joblib.load(open(extracted_result_path, "rb"))['extracted_results']
    if topn:
        extracted_results = extracted_results.head(topn)

    jenv = jinja2.Environment()
    template = jenv.from_string(config['regularization_prompt'])
    system_prompt = template.render()
    template = jenv.from_string(config['regularization_user_content'])

    items = []
    for i, d in tqdm(enumerate(extracted_results.itertuples())):
        idx = d.index  # d[1]
        category = d.category
        business_id = d.business_id
        if pd.isna(category) or len(category) == 0:
            category = ""
        user_content = template.render( category = category)
        item = {
            "custom_id": str(business_id),
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": config['model'],
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content}
                ],
                "max_tokens": config['max_tokens'],
                "temperature": config['temperature'],
                "response_format": {"type": "json_object"},
            }
        }
        items.append( json.dumps(item, ensure_ascii=False))

    with open(output_path, "w") as f:
        for item in items:
            f.write(item + "\n")

def run_batch( input_path: str, job_path: str, jsonl_path: str):
    """Upload a prepared batch input file, create a batch job, poll until it finishes,
    and save the batch output.

    Argument
    --------
    input_path: str
        Path to the prepared batch input file (result from prepare_batch)
    job_path: str
        Path to the job file (response from creating a batch job)
    jsonl_path: str
        Path to the output JSONL file
    """
    assert os.path.exists(input_path), f"File not found: {input_path}"
    st = time.time()
    client = OpenAI( organization = os.getenv('ORGANIZATION_ID'))
    batch_input_file = client.files.create(
        file=open( input_path, "rb"),
        purpose="batch"
    )
    batch_input_file_id = batch_input_file.id
    logger.info(f"batch_input_file_id -> {batch_input_file_id}")

    batch_resp = client.batches.create(
        input_file_id=batch_input_file_id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            "description": "batch job"
        }
    )
    logger.info(f"batch resp -> {batch_resp}")
    try:
        with open( job_path, "wb") as f:
            joblib.dump(batch_resp, f)
    except Exception as e:
        logger.error(f"Error -> {e}")
        with open("./job.joblib", "wb") as f:
            joblib.dump(batch_resp, f)

    is_ready = False
    while True:
        batch_resp = client.batches.retrieve(batch_resp.id)
        if batch_resp.status == 'validating':
            logger.info("the input file is being validated before the batch can begin")
        elif batch_resp.status == 'failed':
            logger.info("the input file has failed the validation process")
            break
        elif batch_resp.status == 'in_progress':
            logger.info("the input file was successfully validated and the batch is currently being run")
        elif batch_resp.status == 'finalizing':
            logger.info("the batch has completed and the results are being prepared")
        elif batch_resp.status == 'completed':
            logger.info("the batch has been completed and the results are ready")
            is_ready = True
            break
        elif batch_resp.status == 'expired':
            logger.info("the batch was not able to be completed within the 24-hour time window")
            break
        elif batch_resp.status == 'cancelling':
            logger.info("the batch is being cancelled (may take up to 10 minutes)")
        elif batch_resp.status == 'cancelled':
            logger.info("the batch was cancelled")
            break
        else:
            logger.error(f"Invalid status -> {batch_resp.status}")
            raise ValueError(f"Invalid batch status: {batch_resp.status}")
        time.sleep(10)

    if is_ready:
        output_resp = client.files.content(batch_resp.output_file_id)
        llm_results = []
        try:
            with open(jsonl_path, "w") as f:
                for line in output_resp.content.decode('utf-8').split("\n"):
                    line = line.strip()
                    if len(line) == 0:
                        continue
                    llm_results.append(line)
                    f.write(f"{line}\n")
        except Exception as e:
            logger.error(f"Error -> {e}")
            with open("./output.jsonl", "w") as f:
                for line in output_resp.content.decode('utf-8').split("\n"):
                    line = line.strip()
                    if len(line) == 0:
                        continue
                    llm_results.append(line)
                    f.write(f"{line}\n")
    print( f"Time elapsed: {time.time()-st:.2f} seconds")

def batch2extract( jsonl_path: str, crawled_result_path: str, extracted_result_path: str):
    """Parse the batch output and merge it with the crawled records into an extraction result file.

    Argument
    --------
    jsonl_path: str
        Path to the batch output file
    crawled_result_path: str
        Path to the crawled result file (result from the crawl task)
    extracted_result_path: str
        Path to the extracted result file
    """
    assert os.path.exists(jsonl_path), f"File not found: {jsonl_path}"
    assert os.path.exists(crawled_result_path), f"File not found: {crawled_result_path}"
    crawled_results = joblib.load(open(crawled_result_path, "rb"))

    extracted_results = []
    empty_indices = []
    llm_results = []
    for line in open(jsonl_path, "r"):
        line = line.strip()
        if len(line) == 0:
            continue
        llm_results.append(line)

    for i, llm_result in enumerate(llm_results):
        try:
            llm_result = json.loads(llm_result)
            business_id = llm_result['custom_id']
            llm_result = llm_result['response']['body']['choices'][0]['message']['content']
            llm_result = parse_json_garbage(llm_result)
            llm_result['business_id'] = business_id
            extracted_results.append(llm_result)
        except Exception as e:
            logger.error(f"Error -> {e}, llm_result -> {llm_result}")
            empty_indices.append(i)
    extracted_results = pd.DataFrame(extracted_results)

    basic_info = []
    for i, d in tqdm(enumerate(crawled_results['crawled_results'].itertuples())):
        idx = d.index  # d[1]
        evidence = d.googlemap_results + "\n" + d.search_results
        business_id = d.business_id  # d[2]
        business_name = d.business_name  # d[3]
        address = d.address  # d[7]
        basic_info.append( {
            "index": idx,
            "business_id": business_id,
            "business_name": business_name,
            "evidence": evidence,
        } )
    basic_info = pd.DataFrame(basic_info)

    extracted_results = basic_info.astype({"business_id": str}).merge(extracted_results, on="business_id", how="inner")
    print( f"{extracted_results.shape[0]} records merged.")
    extracted_results = {"extracted_results": extracted_results, "empty_indices": empty_indices}
    with open(extracted_result_path, "wb") as f:
        joblib.dump(extracted_results, f)

def batch2reg( jsonl_path: str, extracted_result_path: str, regularized_result_path: str):
    """Parse the batch output and merge it with the extracted records into a regularization result file.

    Argument
    --------
    jsonl_path: str
        Path to the batch output file
    extracted_result_path: str
        Path to the extracted result file
    regularized_result_path: str
        Path to the regularization result file
    """
    assert os.path.exists(jsonl_path), f"File not found: {jsonl_path}"
    assert os.path.exists(extracted_result_path), f"File not found: {extracted_result_path}"
    extracted_results = joblib.load(open(extracted_result_path, "rb"))['extracted_results']

    llm_results, regularized_results, empty_indices = [], [], []
    for line in open(jsonl_path, "r"):
        line = line.strip()
        if len(line) == 0:
            continue
        llm_results.append(line)

    for i, llm_result in enumerate(llm_results):
        try:
            llm_result = json.loads(llm_result)
            business_id = llm_result['custom_id']
            llm_result = llm_result['response']['body']['choices'][0]['message']['content']
            llm_result = parse_json_garbage(llm_result)
            llm_result['business_id'] = business_id
            regularized_results.append(llm_result)
        except Exception as e:
            logger.error(f"Error -> {e}, llm_result -> {llm_result}")
            empty_indices.append(i)
    regularized_results = pd.DataFrame(regularized_results)

    basic_info = []
    for i, d in tqdm(enumerate(extracted_results.itertuples())):
        idx = d.index  # d[1]
        evidence = d.evidence
        business_id = d.business_id  # d[2]
        business_name = d.business_name  # d[3]
        basic_info.append( {
            "index": idx,
            "business_id": business_id,
            "business_name": business_name,
            "evidence": evidence,
        } )
    basic_info = pd.DataFrame(basic_info)

    regularized_results = basic_info.astype({"business_id": str}).merge(regularized_results, on="business_id", how="inner")
    print( f"{regularized_results.shape[0]} records merged.")
    regularized_results = {"regularized_results": regularized_results, "empty_indices": empty_indices}
    with open(regularized_result_path, "wb") as f:
        joblib.dump(regularized_results, f)

def postprocess_result( config: dict, regularized_result_path: str, postprocessed_result_path, category_hierarchy: dict, column_name: str = 'category') -> pd.DataFrame:
    """Map each regularized category to its supercategory and save the result as CSV.

    Argument
    --------
    config: dict
        Configuration for the batch job
    regularized_result_path: str
        Path to the regularization result file (contains `evidence` and the regularized fields)
    postprocessed_result_path: str
        Path to the postprocessed CSV file
    category_hierarchy: dict
        Mapping from category to supercategory
    column_name: str
        Name of the column holding the regularized category

    Return
    ------
    postprocessed_results: pd.DataFrame
    """
    assert os.path.exists(regularized_result_path), f"File not found: {regularized_result_path}"
    regularized_results = joblib.load(open(regularized_result_path, "rb"))['regularized_results']
    if True:  # if not os.path.exists(postprocessed_result_path):
        postprocessed_results = regularized_results.copy()
        # blank out categories that are not part of the hierarchy, then look up the supercategory
        postprocessed_results.loc[ :, "category"] = postprocessed_results[column_name].progress_apply(lambda x: "" if x not in category_hierarchy else x)
        postprocessed_results['supercategory'] = postprocessed_results[column_name].progress_apply(lambda x: category_hierarchy.get(x, ''))
        postprocessed_results.to_csv( postprocessed_result_path, index=False)
    else:
        postprocessed_results = pd.read_csv( postprocessed_result_path)
    return postprocessed_results
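
# postprocess_result() expects `category_hierarchy` to be the `category2supercategory` mapping
# taken from the config (see the 'postprocess' task below). A minimal sketch of the assumed
# YAML shape, with illustrative entries only (not the project's actual taxonomy):
#
#   category2supercategory:
#     "燒烤": "中式"
#     "西餐廳(含美式,義式,墨式)": "西式"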

def combine_postprocessed_results( config: dict, input_path: str, postprocessed_result_path: str, reference_path: str, output_path: str):
    """Collect the postprocessed CSV files and join them onto the reference lead list.

    Argument
    --------
    config: dict
        Configuration for the batch job
    input_path: str
        Directory that contains the postprocessed result folders
    postprocessed_result_path: str
        Subdirectory or glob pattern under `input_path` that holds `postprocessed_results.csv`
    reference_path: str
        Path to the reference lead file (loaded with `get_leads`)
    output_path: str
        Path to the combined CSV file
    """
    file_pattern = str(Path(input_path).joinpath( postprocessed_result_path, "postprocessed_results.csv"))
    logger.info(f"file_pattern -> {file_pattern}")
    file_paths = list(glob.glob(file_pattern))
    assert len(file_paths) > 0, f"File not found: {postprocessed_result_path}"

    postprocessed_results = pd.concat([pd.read_csv(file_path, dtype={"business_id": str}) for file_path in file_paths], axis=0)
    reference_results = get_leads( reference_path)
    # reference_results = reference_results.rename(config['column_mapping'], axis=1)
    # join on the unified business number (統一編號) of each lead
    postprocessed_results = reference_results.merge( postprocessed_results, left_on = "統一編號", right_on="business_id", how="left")
    postprocessed_results.to_csv( output_path, index=False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument( "-c", "--config", type=str, default='config/config.yml', help="Path to the configuration file")
    parser.add_argument( "-t", "--task", type=str, default='prepare_batch', choices=['prepare_batch', 'prepare_regularization', 'run_batch', 'batch2extract', 'batch2reg', 'postprocess', 'combine'])
    parser.add_argument( "-i", "--input_path", type=str, default='')
    parser.add_argument( "-o", "--output_path", type=str, default='')
    parser.add_argument( "-b", "--batch_path", type=str, default='')
    parser.add_argument( "-j", "--job_path", type=str, default='')
    parser.add_argument( "-jp", "--jsonl_path", type=str, default='')
    parser.add_argument( "-crp", "--crawled_result_path", type=str, default='')
    parser.add_argument( "-erp", "--extracted_result_path", type=str, default='')
    parser.add_argument( "-rrp", "--regularized_result_path", type=str, default='')
    parser.add_argument( "-prp", "--postprocessed_result_path", type=str, default='')
    parser.add_argument( "-rp", "--reference_path", type=str, default='')
    parser.add_argument( "-topn", "--topn", type=int, default=None)
    args = parser.parse_args()

    # classes = ['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒', '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)', ]
    # backup_classes = [ '中式', '西式']

    assert os.path.exists(args.config), f"File not found: {args.config}"
    config = yaml.safe_load(open(args.config, "r").read())

    if args.task == 'prepare_batch':
        prepare_batch( crawled_result_path = args.crawled_result_path, config = config, output_path = args.output_path, topn = args.topn)
    elif args.task == 'run_batch':
        run_batch( input_path = args.input_path, job_path = args.job_path, jsonl_path = args.jsonl_path)
    elif args.task == 'prepare_regularization':
        prepare_regularization( extracted_result_path = args.extracted_result_path, config = config, output_path = args.output_path, topn = args.topn)
    elif args.task == 'batch2extract':
        batch2extract(
            jsonl_path = args.jsonl_path,
            crawled_result_path = args.crawled_result_path,
            extracted_result_path = args.extracted_result_path
        )
    elif args.task == 'batch2reg':
        batch2reg(
            jsonl_path = args.jsonl_path,
            extracted_result_path = args.extracted_result_path,
            regularized_result_path = args.regularized_result_path
        )
    elif args.task == 'postprocess':
        postprocess_result(
            config = config,
            regularized_result_path = args.regularized_result_path,
            postprocessed_result_path = args.postprocessed_result_path,
            category_hierarchy = config['category2supercategory'],
            column_name = 'category'
        )
    elif args.task == 'combine':
        combine_postprocessed_results(
            config,
            args.input_path,
            args.postprocessed_result_path,
            args.reference_path,
            args.output_path
        )
    else:
        raise Exception("Invalid task")
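
# A rough end-to-end invocation sequence for this script (assuming the file is saved as
# batch.py; all paths below are hypothetical examples, and the crawl step that produces the
# crawled-result joblib lives outside this file):
#
#   python batch.py -t prepare_batch          -c config/config.yml -crp data/crawled_results.joblib -o data/batch_input.jsonl
#   python batch.py -t run_batch              -i data/batch_input.jsonl -j data/job.joblib -jp data/batch_output.jsonl
#   python batch.py -t batch2extract          -jp data/batch_output.jsonl -crp data/crawled_results.joblib -erp data/extracted_results.joblib
#   python batch.py -t prepare_regularization -c config/config.yml -erp data/extracted_results.joblib -o data/reg_input.jsonl
#   python batch.py -t run_batch              -i data/reg_input.jsonl -j data/reg_job.joblib -jp data/reg_output.jsonl
#   python batch.py -t batch2reg              -jp data/reg_output.jsonl -erp data/extracted_results.joblib -rrp data/regularized_results.joblib
#   python batch.py -t postprocess            -c config/config.yml -rrp data/regularized_results.joblib -prp data/run_01/postprocessed_results.csv
#   python batch.py -t combine                -i data -prp "run_*" -rp data/leads.csv -o data/combined_results.csv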