import os import sys import json import glob import time import yaml import joblib import argparse import jinja2 import anthropic import pandas as pd from tqdm import tqdm from pathlib import Path from loguru import logger from openai import OpenAI from dotenv import load_dotenv import google.generativeai as genai from google.generativeai.types import HarmCategory, HarmBlockThreshold from data import get_leads from utils import parse_json_garbage, compose_query tqdm.pandas() try: logger.remove(0) logger.add(sys.stderr, level="INFO") except ValueError: pass load_dotenv() def prepare_batch( crawled_result_path: str, config: dict, output_path: str, topn: int = None): """ Argument -------- crawled_result_path: str Path to the crawled result file (result from the crawl task) config: dict Configuration for the batch job output_path: str Path to the output file Return ------ items: list Example {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} model = model, response_format = {"type": "json_object"}, temperature = 0, max_tokens = 4096, """ assert os.path.exists(crawled_result_path), f"File not found: {crawled_result_path}" crawled_results = joblib.load(open(crawled_result_path, "rb"))['crawled_results'] if topn: crawled_results = crawled_results.head(topn) jenv = jinja2.Environment() template = jenv.from_string(config['extraction_prompt']) system_prompt = template.render( classes = config['classes'], traits = config['traits']) template = jenv.from_string(config['user_content']) items = [] for i, d in tqdm(enumerate(crawled_results.itertuples())): idx = d.index # d[1] evidence = d.googlemap_results +"\n" + d.search_results business_id = d.business_id # d[2] business_name = d.business_name # d[3] address = d.address # d[7] ana_res = None query = compose_query( address, business_name, use_exclude=False) user_content = template.render( query = query, search_results = evidence) item = { "custom_id": str(business_id), "method": "POST", "url": "/v1/chat/completions", "body": { "model": config['model'], "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_content} ], "max_tokens": config['max_tokens'], "temperature": config['temperature'], "response_format": {"type": "json_object"}, } } items.append( json.dumps(item, ensure_ascii=False)) with open(output_path, "w") as f: for item in items: f.write(item + "\n") def prepare_regularization( extracted_result_path: str, config: dict, output_path: str, topn: int = None): """ Argument -------- extracted_file_path: str Path to the extracted result file (result from the extraction task) config: dict Configuration for the batch job output_path: str Path to the output file topn: int Number of records to be processed Return ------ items: list Example {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} model = model, response_format = {"type": "json_object"}, temperature = 0, max_tokens = 4096, """ assert os.path.exists(extracted_result_path), f"File not found: {extracted_result_path}" extracted_results = joblib.load(open(extracted_result_path, "rb"))['extracted_results'] if topn: extracted_results = extracted_results.head(topn) jenv = jinja2.Environment() template = jenv.from_string(config['regularization_prompt']) system_prompt = template.render() template = jenv.from_string(config['regularization_user_content']) items = [] for i, d in tqdm(enumerate(extracted_results.itertuples())): idx = d.index # d[1] category = d.category business_id = d.business_id if pd.isna(category) or len(category)==0: category = "" user_content = template.render( category = category) item = { "custom_id": str(business_id), "method": "POST", "url": "/v1/chat/completions", "body": { "model": config['model'], "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_content} ], "max_tokens": config['max_tokens'], "temperature": config['temperature'], "response_format": {"type": "json_object"}, } } items.append( json.dumps(item, ensure_ascii=False)) with open(output_path, "w") as f: for item in items: f.write(item + "\n") def run_batch( input_path: str, job_path: str, jsonl_path: str): """ Argument -------- input_path: str Path to the prepared batch input file (result from prepare_batch) job_path: str Path to the job file (response from creating a batch job) jsonl_path: str Path to the output file extracted_result_path: str Path to the extracted result file """ assert os.path.exists(input_path), f"File not found: {input_path}" st = time.time() client = OpenAI( organization = os.getenv('ORGANIZATION_ID')) batch_input_file = client.files.create( file=open( input_path, "rb"), purpose="batch" ) batch_input_file_id = batch_input_file.id logger.info(f"batch_input_file_id -> {batch_input_file_id}") batch_resp = client.batches.create( input_file_id=batch_input_file_id, endpoint="/v1/chat/completions", completion_window="24h", metadata={ "description": "batch job" } ) logger.info(f"batch resp -> {batch_resp}") try: with open( job_path, "wb") as f: joblib.dump(batch_resp, f) except Exception as e: logger.error(f"Error -> {e}") with open("./job.joblib", "wb") as f: joblib.dump(batch_resp, f) is_ready = False while 1: batch_resp = client.batches.retrieve(batch_resp.id) if batch_resp.status == 'validating': logger.info("the input file is being validated before the batch can begin") elif batch_resp.status == 'failed': logger.info("the input file has failed the validation process") break elif batch_resp.status == 'in_progress': logger.info("the input file was successfully validated and the batch is currently being ru") elif batch_resp.status == 'finalizing': logger.info("the batch has completed and the results are being prepared") elif batch_resp.status == 'completed': logger.info("the batch has been completed and the results are ready") is_ready = True break elif batch_resp.status == 'expired': logger.info("the batch was not able to be completed within the 24-hour time window") break elif batch_resp.status == 'cancelling': logger.info("the batch is being cancelled (may take up to 10 minutes)") elif batch_resp.status == 'cancelled': logger.info("the batch was cancelled") break else: raise logger.error("Invalid status") time.sleep(10) if is_ready: output_resp = client.files.content(batch_resp.output_file_id) llm_results = [] try: with open(jsonl_path, "w") as f: for line in output_resp.content.decode('utf-8').split("\n"): line = line.strip() if len(line)==0: break llm_results.append(line) f.write(f"{line}\n") except Exception as e: logger.error(f"Error -> {e}") with open("./output.jsonl", "w") as f: for line in output_resp.content.decode('utf-8').split("\n"): line = line.strip() if len(line)==0: break llm_results.append(line) f.write(f"{line}\n") print( f"Time elapsed: {time.time()-st:.2f} seconds") def batch2extract( jsonl_path: str, crawled_result_path: str, extracted_result_path: str): """ Argument -------- jsonl_path: str Path to the batch output file crawled_result_path: str Path to the crawled result file (result from the crawl task) extracted_result_path: str Path to the extracted result file """ assert os.path.exists(jsonl_path), f"File not found: {jsonl_path}" assert os.path.exists(crawled_result_path), f"File not found: {crawled_result_path}" crawled_results = joblib.load(open(crawled_result_path, "rb")) extracted_results = [] empty_indices = [] llm_results = [] for line in open(jsonl_path, "r"): line = line.strip() if len(line)==0: break llm_results.append(line) for i,llm_result in enumerate(llm_results): try: llm_result = json.loads(llm_result) business_id = llm_result['custom_id'] llm_result = llm_result['response']['body']['choices'][0]['message']['content'] llm_result = parse_json_garbage(llm_result) llm_result['business_id'] = business_id extracted_results.append(llm_result) except Exception as e: logger.error(f"Error -> {e}, llm_result -> {llm_result}") empty_indices.append(i) extracted_results = pd.DataFrame(extracted_results) basic_info = [] for i, d in tqdm(enumerate(crawled_results['crawled_results'].itertuples())): idx = d.index # d[1] evidence = d.googlemap_results +"\n" + d.search_results business_id = d.business_id # d[2] business_name = d.business_name # d[3] address = d.address # d[7] # ana_res = None # query = compose_query( address, business_name, use_exclude=False) basic_info.append( { "index": idx, "business_id": business_id, "business_name": business_name, "evidence": evidence, # ** ext_res } ) basic_info = pd.DataFrame(basic_info) extracted_results = basic_info.astype({"business_id": str}).merge(extracted_results, on="business_id", how="inner") print( f"{ extracted_results.shape[0]} records merged.") extracted_results = {"extracted_results": extracted_results, "empty_indices": empty_indices} with open(extracted_result_path, "wb") as f: joblib.dump(extracted_results, f) def batch2reg( jsonl_path: str, extracted_result_path: str, regularized_result_path: str): """ Argument -------- jsonl_path: str Path to the batch output file extracted_result_path: str Path to the extracted result file regularized_result_path: str Path to the regularization result file """ assert os.path.exists(jsonl_path), f"File not found: {jsonl_path}" assert os.path.exists(extracted_result_path), f"File not found: {extracted_result_path}" extracted_results = joblib.load(open(extracted_result_path, "rb"))['extracted_results'] llm_results, regularized_results, empty_indices = [], [], [] for line in open(jsonl_path, "r"): line = line.strip() if len(line)==0: break llm_results.append(line) for i,llm_result in enumerate(llm_results): try: llm_result = json.loads(llm_result) business_id = llm_result['custom_id'] llm_result = llm_result['response']['body']['choices'][0]['message']['content'] llm_result = parse_json_garbage(llm_result) llm_result['business_id'] = business_id regularized_results.append(llm_result) except Exception as e: logger.error(f"Error -> {e}, llm_result -> {llm_result}") empty_indices.append(i) regularized_results = pd.DataFrame(regularized_results) basic_info = [] for i, d in tqdm(enumerate(extracted_results.itertuples())): idx = d.index # d[1] # evidence = d.googlemap_results +"\n" + d.search_results evidence = d.evidence business_id = d.business_id # d[2] business_name = d.business_name # d[3] # address = d.address # d[7] # ana_res = None # query = compose_query( address, business_name, use_exclude=False) basic_info.append( { "index": idx, "business_id": business_id, "business_name": business_name, "evidence": evidence, # ** ext_res } ) basic_info = pd.DataFrame(basic_info) regularized_results = basic_info.astype({"business_id": str}).merge(regularized_results, on="business_id", how="inner") print( f"{ regularized_results.shape[0]} records merged.") regularized_results = {"regularized_results": regularized_results, "empty_indices": empty_indices} with open(regularized_result_path, "wb") as f: joblib.dump(regularized_results, f) def postprocess_result( config: dict, regularized_result_path: str, postprocessed_result_path, category_hierarchy: dict, column_name: str = 'category') -> pd.DataFrame: """ Argument config: dict regularized_results_path: str analysis_result: `evidence`, `result` postprocessed_results_path Return """ assert os.path.exists(regularized_result_path), f"File not found: {regularized_result_path}" regularized_results = joblib.load(open(regularized_result_path, "rb"))['regularized_results'] if True: # if not os.path.exists(postprocessed_result_path): postprocessed_results = regularized_results.copy() postprocessed_results.loc[ :, "category"] = postprocessed_results[column_name].progress_apply(lambda x: "" if x not in category_hierarchy else x) postprocessed_results['supercategory'] = postprocessed_results[column_name].progress_apply(lambda x: category_hierarchy.get(x, '')) # with open( postprocessed_results_path, "wb") as f: # joblib.dump( postprocessed_results, f) postprocessed_results.to_csv( postprocessed_result_path, index=False) else: # with open( postprocessed_results_path, "rb") as f: # postprocessed_results = joblib.load(f) postprocessed_results = pd.read_csv( postprocessed_result_path) return postprocessed_results def combine_postprocessed_results( config: dict, input_path: str, postprocessed_result_path: str, reference_path: str, output_path: str): """ Argument config: dict input_path: str postprocessed_result_path: str reference_path: str output_path: str """ file_pattern = str(Path(input_path).joinpath( postprocessed_result_path, "postprocessed_results.csv")) logger.info(f"file_pattern -> {file_pattern}") file_paths = list(glob.glob(file_pattern)) assert len(file_paths)>0, f"File not found: {postprocessed_result_path}" postprocessed_results = pd.concat([pd.read_csv(file_path, dtype={"business_id": str}) for file_path in file_paths], axis=0) reference_results = get_leads( reference_path) # reference_results = reference_results.rename(config['column_mapping'], axis=1) postprocessed_results = reference_results.merge( postprocessed_results, left_on = "統一編號", right_on="business_id", how="left") postprocessed_results.to_csv( output_path, index=False) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "-c", "--config", type=str, default='config/config.yml', help="Path to the configuration file") parser.add_argument( "-t", "--task", type=str, default='prepare_batch', choices=['prepare_batch', 'prepare_regularization', 'run_batch', 'batch2extract', 'batch2reg', 'postprocess', 'combine']) parser.add_argument( "-i", "--input_path", type=str, default='', ) parser.add_argument( "-o", "--output_path", type=str, default='', ) parser.add_argument( "-b", "--batch_path", type=str, default='', ) parser.add_argument( "-j", "--job_path", type=str, default='', ) parser.add_argument( "-jp", "--jsonl_path", type=str, default='', ) parser.add_argument( "-crp", "--crawled_result_path", type=str, default='', ) parser.add_argument( "-erp", "--extracted_result_path", type=str, default='', ) parser.add_argument( "-rrp", "--regularized_result_path", type=str, default='', ) parser.add_argument( "-prp", "--postprocessed_result_path", type=str, default='', ) parser.add_argument( "-rp", "--reference_path", type=str, default='', ) parser.add_argument( "-topn", "--topn", type=int, default=None ) args = parser.parse_args() # classes = ['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒', '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)', ] # backup_classes = [ '中式', '西式'] assert os.path.exists(args.config), f"File not found: {args.config}" config = yaml.safe_load(open(args.config, "r").read()) if args.task == 'prepare_batch': prepare_batch( crawled_result_path = args.crawled_result_path, config = config, output_path = args.output_path, topn = args.topn) elif args.task == 'run_batch': run_batch( input_path = args.input_path, job_path = args.job_path, jsonl_path = args.jsonl_path) elif args.task == 'prepare_regularization': prepare_regularization( extracted_result_path = args.extracted_result_path, config = config, output_path = args.output_path, topn = args.topn) elif args.task == 'batch2extract': batch2extract( jsonl_path = args.jsonl_path, crawled_result_path = args.crawled_result_path, extracted_result_path = args.extracted_result_path ) elif args.task == 'batch2reg': batch2reg( jsonl_path = args.jsonl_path, extracted_result_path = args.extracted_result_path, regularized_result_path = args.regularized_result_path ) elif args.task == 'postprocess': postprocess_result( config = config, regularized_result_path = args.regularized_result_path, postprocessed_result_path = args.postprocessed_result_path, category_hierarchy = config['category2supercategory'], column_name = 'category' ) elif args.task == 'combine': combine_postprocessed_results( config, args.input_path, args.postprocessed_result_path, args.reference_path, args.output_path ) else: raise Exception("Invalid task")