import os
import time
import json
import joblib
import math
import itertools
import argparse
import multiprocessing as mp

import pandas as pd
from dotenv import load_dotenv
from serpapi import GoogleSearch
import tiktoken
from openai import OpenAI
from tqdm import tqdm

from model import llm
from utils import parse_json_garbage

load_dotenv()
ORGANIZATION_ID = os.getenv('OPENAI_ORGANIZATION_ID')
SERP_API_KEY = os.getenv('SERP_APIKEY')
def get_leads( file_path: str, names: list = [
        '營業地址', '統一編號', '總機構統一編號', '營業人名稱', '資本額', '設立日期', '組織別名稱', '使用統一發票',
        '行業代號', '名稱', '行業代號1', '名稱1', '行業代號2', '名稱2', '行業代號3', '名稱3']):
    """Read the lead list (business registry CSV) into a dataframe.
    Argument
        file_path: str, path to the CSV file
        names: list, column names to assign to the CSV
    Return
        data: dataframe
    """
    assert os.path.exists(file_path)
    data = pd.read_csv( file_path, names=names)
    return data
def get_serp( query: str, google_domain: str, gl: str, lr: str) -> dict:
    """Query SerpAPI (Google engine) and return the raw result dictionary.
    Argument
        query: str, search query
        google_domain: str, e.g. 'google.com.tw'
        gl: str, country code, e.g. 'tw'
        lr: str, language restriction, e.g. 'lang_zh-TW'
    Return
        result: dict, raw SerpAPI response
    """
    search = GoogleSearch({
        "q": query,
        'google_domain': google_domain,
        'gl': gl,
        'lr': lr,
        "api_key": SERP_API_KEY
    })
    result = search.get_dict()
    # print(result['organic_results'][0])
    # return result['organic_results'][0]
    return result
def get_condensed_result(result):
    """Condense a raw SerpAPI result into a compact JSON string of evidence.
    Argument
        result: dict, raw SerpAPI response
    Return
        condensed_result: str, JSON string of organic titles/snippets plus selected knowledge-graph fields
    Example
        result['knowledge_graph'].keys() # 'title', 'thumbnail', 'type', 'entity_type', 'kgmid', 'knowledge_graph_search_link', 'serpapi_knowledge_graph_search_link', 'tabs', 'place_id', 'directions', 'local_map', 'rating', 'review_count', '服務項目', '地址', '地址_links', 'raw_hours', 'hours', '電話號碼', '電話號碼_links', 'popular_times', 'user_reviews', 'reviews_from_the_web', 'unclaimed_listing', '個人資料', '其他人也搜尋了以下項目', '其他人也搜尋了以下項目_link', '其他人也搜尋了以下項目_stick'
    """
    filtered_results = [
        {"title": r.get('title', ""), 'snippet': r.get('snippet', "")} for r in result['organic_results']
    ]
    if 'knowledge_graph' in result:
        if 'user_reviews' in result['knowledge_graph']:
            filtered_results.append( {'title': result['knowledge_graph']['title'], '顧客評價': "\t".join([ _.get('summary', '') for _ in result['knowledge_graph']['user_reviews']]) })
        if '其他人也搜尋了以下項目' in result['knowledge_graph']:
            filtered_results.append( {'title': "類似的店", 'snippet': "\t".join([ str(_.get('extensions', '')) for _ in result['knowledge_graph']['其他人也搜尋了以下項目']]) })
        if '暫停營業' in result['knowledge_graph']:
            filtered_results.append( {'status': '暫停營業' if result['knowledge_graph']['暫停營業'] else '營業中'})
        if '電話號碼' in result['knowledge_graph']:
            filtered_results.append( {'telephone_number': result['knowledge_graph']['電話號碼']})
    condensed_result = json.dumps(filtered_results, ensure_ascii=False)
    # print( condensed_result )
    return condensed_result
def compose_extraction( query, search_results, classes: list, provider: str, model: str):
    """Compose an LLM request that extracts structured store fields from search results.
    Argument
        query: str
        search_results: str
        classes: list, `小吃店`, `日式料理(含居酒屋,串燒)`, `火(鍋/爐)`, `東南亞料理(不含日韓)`, `海鮮熱炒`, `特色餐廳(含雞、鵝、牛、羊肉)`, `傳統餐廳`, `燒烤`, `韓式料理(含火鍋,烤肉)`, `西餐廳(含美式,義式,墨式)`, `西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)`, `西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)` or `早餐`
        provider: "openai"
        model: "gpt-4-0125-preview" or 'gpt-3.5-turbo-0125'
    Return
        response: str
    """
    classes = ", ".join([ "`"+x+"`" for x in classes if x!='早餐' ]) + " or " + "`早餐`"
    system_prompt = f'''
    As a helpful and rigorous retail analyst, given the provided query and a list of search results for the query,
    your task is to first identify relevant information about the identical store based on the store name and the proximity of the address if known. After that, extract `store_name`, `address`, `description`, `category` and `phone_number` from the found relevant information, where `category` can only be {classes}.
    It's very important to omit unrelated results. Do not make up any assumption.
    Please think step by step, and output in json format. An example output json is like {{"store_name": "...", "address": "...", "description": "... products, service or highlights ...", "category": "...", "phone_number": "..."}}
    If no relevant information has been found, simply output json with empty values.
    I'll tip you and guarantee a place in heaven if you do a great job completely according to my instruction.
    '''
    user_content = f"`query`: `{query}`\n`search_results`: {search_results}"
    response = llm(
        provider = provider,
        model = model,
        system_prompt = system_prompt,
        user_content = user_content
    )
    return response
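# A minimal sketch of how the extraction step is used downstream (see extract_results).
# `llm` and `parse_json_garbage` come from the local model/utils modules; the query and
# evidence strings below are placeholders, not real data.
# raw = compose_extraction(
#     query = "台北市 某某小吃店",
#     search_results = '[{"title": "某某小吃店", "snippet": "滷肉飯、魚湯"}]',
#     classes = ['小吃店', '早餐'],
#     provider = "openai",
#     model = "gpt-3.5-turbo-0125",
# )
# fields = parse_json_garbage(raw)   # expected shape: {"store_name": ..., "category": ..., ...}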
def compose_classication( user_content, classes: list, backup_classes: list, provider: str, model: str) -> str:
    """Compose an LLM request that classifies a piece of evidence into one category.
    Argument
        user_content: str, the evidence to classify
        classes: list
        backup_classes: list
        provider: e.g. 'openai'
        model: e.g. 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview'
    Return
        response: str
    """
    if isinstance(classes, list):
        classes = ", ".join([ f"`{x}`" for x in classes])
    elif isinstance(classes, str):
        pass
    else:
        raise Exception(f"Incorrect classes type: {type(classes)}")
    system_prompt = f"""
    As a helpful and rigorous retail analyst, given the provided information about a store,
    your task is two-fold. First, classify the provided evidence below into the most relevant category from the following: {classes}.
    Second, if no relevant information has been found, classify the evidence into the most relevant supercategory from the following: {backup_classes}.
    It's very important to omit unrelated pieces of evidence and not to make up any assumption.
    Please think step by step, and you must output in json format. An example output json is like {{"category": "..."}}
    If no relevant piece of information can ever be found at all, simply output json with an empty string "".
    I'll tip you and guarantee a place in heaven if you do a great job completely according to my instruction.
    """
    response = llm(
        provider = provider,
        model = model,
        system_prompt = system_prompt,
        user_content = user_content,
    )
    return response
def classify_results(
        analysis_results: pd.DataFrame,
        classes: list,
        backup_classes: list,
        provider: str,
        model: str,
        input_column: str = 'evidence',
        output_column: str = 'classified_category',
    ):
    """Classify the results
    Argument
        analysis_results: dataframe
        classes: list
        backup_classes: list
        provider: str
        model: str
        input_column: str
        output_column: str
    Return
        classified_results: dataframe with the `classified_category` column added
        empty_indices: list of indices whose classification failed
    """
    classified_results = analysis_results.copy()
    labels, empty_indices = [], []
    for idx, evidence in zip( analysis_results['index'], analysis_results[input_column]):
        user_content = f'''`evidence`: `{evidence}`'''
        try:
            pred_cls = compose_classication( user_content, classes=classes, backup_classes=backup_classes, provider=provider, model=model)
            label = parse_json_garbage(pred_cls)['category']
            labels.append(label)
        except Exception as e:
            print(f"# CLASSIFICATION error: e -> {e}, user_content -> {user_content}, evidence: {evidence}")
            labels.append("")
            empty_indices.append(idx)
    classified_results[output_column] = labels
    return {
        "classified_results": classified_results,
        "empty_indices": empty_indices
    }
def classify_results_mp( extracted_results: pd.DataFrame, classified_file_path: str, classes: list, backup_classes: list, provider: str, model: str, n_processes: int = 4):
    """Classify extracted results with multiple processes, caching the output to a joblib file.
    Argument
        extracted_results: dataframe
        classified_file_path: str, cache path for the classified results
        classes: e.g. ['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒', '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)']
        backup_classes: e.g. [ '中式', '西式']
        provider: str
        model: str
        n_processes: int
    Return
        classified_results: dict of dataframe (`classified_results`) and list (`empty_indices`)
    Reference
        200 records, 4 processes, 122.4695s
    """
    st = time.time()
    # classified_file_path = "data/classified_result.joblib"
    if not os.path.exists(classified_file_path):
        split_data = split_dataframe(extracted_results, n_processes=n_processes)
        with mp.Pool(n_processes) as pool:
            classified_results = pool.starmap(
                classify_results,
                [ (
                    d,
                    classes, backup_classes,
                    provider, model,
                    'evidence', 'classified_category',
                ) for d in split_data]
            )
        classified_results = merge_results( classified_results, dataframe_columns=['classified_results'], list_columns=['empty_indices'])
        with open( classified_file_path, "wb") as f:
            joblib.dump( classified_results, f)
    else:
        with open( classified_file_path, "rb") as f:
            classified_results = joblib.load(f)
    print( f"total time: {time.time() - st}")
    return classified_results
def compose_query( address, name, with_index: bool = True, exclude: str = "-inurl:twincn.com -inurl:findcompany.com.tw -inurl:iyp.com.tw -inurl:twypage.com -inurl:alltwcompany.com -inurl:zhupiter.com -inurl:twinc.com.tw"):
    """Compose a search query from the city prefix of the address and the business name,
    excluding low-value business-directory sites.
    Argument
        address: str
        name: str
        with_index: bool
        exclude: str, `-inurl:` filters appended to the query
    Return
        query: `縣市` `營業人名稱` plus the exclusion filters
    """
    # if with_index: # .itertuples() with d[1]: 地址, d[4]: 營業人名稱
    #     query = f"{d[1][:3]} {d[4]}"
    # else:
    #     query = f"{d[0][:3]} {d[3]}"
    query = f"{address[:3]} {name} {exclude}"
    return query
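# Quick sketch of what compose_query produces; the address and name are placeholders:
# compose_query("臺北市中山區某某路1號", "某某小吃店")
# -> "臺北市 某某小吃店 -inurl:twincn.com -inurl:findcompany.com.tw ..." (exclusion list truncated here)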
def crawl_results( data: pd.DataFrame, google_domain: str = 'google.com.tw', gl: str = 'tw', lr: str = 'lang_zh-TW'):
    """Crawl SerpAPI for every lead and condense each raw result into evidence.
    Argument
        data: dataframe
        google_domain: str
        gl: str
        lr: str
    Return
        crawled_results: dict of dataframe (`crawled_results`) and list (`empty_indices`)
    Reference
        200 records, 4 processes, 171.36490321159363
    """
    serp_results = []
    condensed_results = []
    crawled_results = []
    empty_indices = []
    for i, d in tqdm(enumerate(data.itertuples())):
        idx = d[0]             # dataframe index
        address = d[1]         # 營業地址
        business_id = d[2]     # 統一編號
        business_name = d[4]   # 營業人名稱
        query = compose_query(address, business_name)
        try:
            res = get_serp( query, google_domain, gl, lr)
            serp_results.append(res)
        except Exception as e:
            print( f"# SERP error: e = {e}, i = {i}, idx = {idx}, query = {query}")
            empty_indices.append(i)
            continue
        try:
            cond_res = get_condensed_result(res)
            condensed_results.append(cond_res)
        except Exception as e:
            print(f"# CONDENSE error: e = {e}, i = {i}, idx = {idx}, res = {res}")
            empty_indices.append(i)
            continue
        crawled_results.append( {
            "index": idx,
            "business_id": business_id,
            "business_name": business_name,
            "serp": res,
            "evidence": cond_res,
            "address": address,
        } )
    crawled_results = pd.DataFrame(crawled_results)
    return {
        "crawled_results": crawled_results,
        "empty_indices": empty_indices
    }
def crawl_results_mp( data: pd.DataFrame, crawl_file_path: str, n_processes: int = 4):
    """Crawl SerpAPI with multiple processes, caching the output to a joblib file."""
    st = time.time()
    # crawl_file_path = "data/crawled_results.joblib"
    if not os.path.exists(crawl_file_path):
        split_data = split_dataframe( data, n_processes=n_processes)
        with mp.Pool(n_processes) as pool:
            crawled_results = pool.map( crawl_results, split_data)
        crawled_results = merge_results( crawled_results, dataframe_columns=['crawled_results'], list_columns=['empty_indices'])
        with open( crawl_file_path, "wb") as f:
            joblib.dump( crawled_results, f)
    else:
        with open( crawl_file_path, "rb") as f:
            crawled_results = joblib.load(f)
    print( f"total time: {time.time() - st}")
    return crawled_results
def extract_results( data: pd.DataFrame, classes: list, provider: str, model: str):
    """Extract structured store fields from the crawled evidence with an LLM.
    Argument
        data: dataframe with `index`, `business_id`, `business_name`, `serp`, `evidence`, `address`
        classes: list
        provider: str
        model: str
    Return
        extracted_results: dict of dataframe (`extracted_results`) and list (`empty_indices`)
    """
    extracted_results, empty_indices, ext_res = [], [], []
    for i, d in tqdm(enumerate(data.itertuples())):
        idx = d[1]            # index
        evidence = d.evidence
        business_id = d[2]    # business_id
        business_name = d[3]  # business_name
        address = d[6]        # address
        query = compose_query( address, business_name)
        try:
            ext_res = compose_extraction( query = query, search_results = evidence, classes = classes, provider = provider, model = model)
            ext_res = parse_json_garbage(ext_res)
        except Exception as e:
            print(f"# ANALYSIS error: e = {e}, i = {i}, q = {query}, ext_res = {ext_res}")
            empty_indices.append(i)
            continue
        extracted_results.append( {
            "index": idx,
            "business_id": business_id,
            "business_name": business_name,
            "evidence": evidence,
            ** ext_res
        } )
    extracted_results = pd.DataFrame(extracted_results)
    return {
        "extracted_results": extracted_results,
        "empty_indices": empty_indices
    }
def extract_results_mp( crawled_results, extracted_file_path, classes: list, provider: str, model: str, n_processes: int = 4):
    """Extract structured fields with multiple processes, caching the output to a joblib file.
    Argument
        crawled_results: dataframe
        extracted_file_path: str, cache path for the extracted results
        classes: list
        provider: str
        model: str
        n_processes: int
    Return
        extracted_results: dict of dataframe (`extracted_results`) and list (`empty_indices`)
    Reference
        200 records, 4 processes, 502.26914715766907
    """
    st = time.time()
    # extracted_file_path = "data/extracted_results.joblib"
    if not os.path.exists(extracted_file_path):
        split_data = split_dataframe( crawled_results, n_processes=n_processes)
        with mp.Pool(n_processes) as pool:
            extracted_results = pool.starmap( extract_results, [ (x, classes, provider, model) for x in split_data])
        extracted_results = merge_results( extracted_results, dataframe_columns=['extracted_results'], list_columns=['empty_indices'])
        with open( extracted_file_path, "wb") as f:
            joblib.dump( extracted_results, f)
    else:
        with open( extracted_file_path, "rb") as f:
            extracted_results = joblib.load(f)
    print( f"total time: {time.time() - st}")
    return extracted_results
def postprocess_result( results: pd.DataFrame, postprocessed_results_path, category_hierarchy: dict, column_name: str = 'category'):
    """Map each category to its supercategory and cache the result to a joblib file.
    Argument
        results: dataframe with a `category` column
        postprocessed_results_path: str, cache path for the postprocessed results
        category_hierarchy: dict, category -> supercategory
        column_name: str
    Return
        postprocessed_results: dataframe with an added `supercategory` column
    """
    # index = analysis_result['result']['index']
    # store_name = data.loc[index]['營業人名稱'] if len(analysis_result['result'].get('store_name',''))==0 else analysis_result['result']['store_name']
    # address = data.loc[index]['營業地址'] if len(analysis_result['result'].get('address',''))==0 else analysis_result['result']['address']
    # post_res = {
    #     "evidence": analysis_result['evidence'],
    #     "index": index,
    #     "begin_date": data.loc[index]['設立日期'],
    #     "store_name": store_name,
    #     "address": address,
    #     "description": analysis_result['result'].get('description', ""),
    #     "phone_number": analysis_result['result'].get('phone_number', ""),
    #     "category": analysis_result['result'].get('category', ""),
    #     "supercategory": category_hierarchy.get(analysis_result['result'].get('category', ""), analysis_result['result'].get('category',"")),
    # }
    if not os.path.exists(postprocessed_results_path):
        postprocessed_results = results.copy()
        postprocessed_results['supercategory'] = postprocessed_results[column_name].apply(lambda x: category_hierarchy.get(x, ''))
        with open( postprocessed_results_path, "wb") as f:
            joblib.dump( postprocessed_results, f)
    else:
        with open( postprocessed_results_path, "rb") as f:
            postprocessed_results = joblib.load(f)
    return postprocessed_results
def combine_results( results: pd.DataFrame, combined_results_path: str, src_column: str = 'classified_category', tgt_column: str = 'category', strategy: str = 'replace'):
    """Combine the classified category into the extracted category and cache the result.
    Argument
        results: dataframe
        combined_results_path: str, cache path for the combined results
        src_column: str
        tgt_column: str
        strategy: str, 'replace' (overwrite whenever the two columns disagree) or 'patch' (fill only empty targets)
    Return
        combined_results: dataframe
    """
    if not os.path.exists(combined_results_path):
        combined_results = results.copy()
        if strategy == 'replace':
            condition = (combined_results[tgt_column]=='') | (combined_results[src_column]!=combined_results[tgt_column])
            combined_results.loc[ condition, tgt_column] = combined_results[condition][src_column].values
        elif strategy == 'patch':
            condition = (combined_results[tgt_column]=='')
            combined_results.loc[ condition, tgt_column] = combined_results[condition][src_column].values
        else:
            raise Exception(f"Strategy {strategy} not implemented")
        with open( combined_results_path, "wb") as f:
            joblib.dump( combined_results, f)
    else:
        with open( combined_results_path, "rb") as f:
            combined_results = joblib.load(f)
    return combined_results
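# A toy illustration of the two merge strategies (made-up rows, not project data):
# df = pd.DataFrame({"category": ["小吃店", ""], "classified_category": ["燒烤", "早餐"]})
# 'patch'   -> category becomes ["小吃店", "早餐"]  (only the empty slot is filled)
# 'replace' -> category becomes ["燒烤", "早餐"]    (any disagreement is overwritten)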
def format_evidence(evidence):
    """Format a JSON evidence string into a human-readable, numbered text block.
    Argument
        evidence: str, JSON string produced by get_condensed_result
    Return
        formatted: str
    """
    formatted = []
    evidence = json.loads(evidence)
    # print( len(evidence) )
    for i in range(len(evidence)):
        if 'title' in evidence[i] and '顧客評價' in evidence[i]:
            f = f"\n> 顧客評價: {evidence[i]['顧客評價']}"
        elif 'title' in evidence[i] and evidence[i]['title']=='類似的店':
            f = f"\n> 類似的店: {evidence[i]['snippet']}"
        elif 'status' in evidence[i]:
            f = f"\n> 經營狀態: {evidence[i]['status']}"
        elif 'telephone_number' in evidence[i]:
            f = f"\n> 電話號碼: {evidence[i]['telephone_number']}"
        else:
            try:
                f = f"{i+1}. {evidence[i]['title']} ({evidence[i].get('snippet','')})"
            except KeyError:
                print( evidence[i] )
                raise
        formatted.append(f)
    return "\n".join(formatted)
def format_output( df: pd.DataFrame, input_column: str = 'evidence', output_column: str = 'formatted_evidence', format_func = format_evidence):
    """Apply a formatting function to one column of a dataframe.
    Argument
        df: dataframe with an `evidence` column
        input_column: str
        output_column: str
        format_func: callable applied to every value of `input_column`
    Return
        formatted_df: dataframe with an added `formatted_evidence` column
    """
    formatted_df = df.copy()
    formatted_df[output_column] = formatted_df[input_column].apply(format_func)
    return formatted_df
def merge_results( results: list, dataframe_columns: list, list_columns: list):
    """Merge the per-process result dicts back into a single dict.
    Argument
        results: a list of dicts, each holding dataframes and lists under the same keys
        dataframe_columns: list of keys whose values are dataframes to concatenate
        list_columns: list of keys whose values are lists to chain
    Return
        merged_results: dict
    """
    assert len(results) > 0, "No results to merge"
    merged_results = {}
    for key in dataframe_columns:
        merged_results[key] = pd.concat([ r[key] for r in results], ignore_index=True)
    for key in list_columns:
        merged_results[key] = list(itertools.chain(*[ r[key] for r in results]))
    return merged_results
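# Toy illustration of merge_results with two fake worker outputs (not project data):
# part1 = {"crawled_results": pd.DataFrame({"index": [0]}), "empty_indices": [1]}
# part2 = {"crawled_results": pd.DataFrame({"index": [2]}), "empty_indices": []}
# merge_results([part1, part2], dataframe_columns=["crawled_results"], list_columns=["empty_indices"])
# -> {"crawled_results": <2-row dataframe>, "empty_indices": [1]}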
def split_dataframe( df: pd.DataFrame, n_processes: int = 4) -> list:
    """Split a dataframe into roughly equal chunks, one per process."""
    n = df.shape[0]
    n_per_process = math.ceil(n / n_processes)
    return [ df.iloc[i:i+n_per_process] for i in range(0, n, n_per_process)]
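# For example, a 10-row dataframe with n_processes=4 is split into chunks of size 3, 3, 3, 1:
# [len(x) for x in split_dataframe(pd.DataFrame(range(10)), n_processes=4)]  # -> [3, 3, 3, 1]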
def continue_missing(args):
    """Re-run extraction for leads missing from the formatted results and merge them back in.
    Argument
        args: argparse
    """
    data = get_leads(args.data_path)
    n_data = data.shape[0]
    formatted_results_path = os.path.join( args.output_dir, args.formatted_results_path)
    formatted_results = pd.read_csv(formatted_results_path)
    missing_indices = []
    for i in range(n_data):
        if i not in formatted_results['index'].unique():
            print(f"{i} is not found")
            missing_indices.append(i)
    crawled_results_path = os.path.join( args.output_dir, args.crawled_file_path)
    crawled_results = joblib.load( open( crawled_results_path, "rb"))
    crawled_results = crawled_results['crawled_results'].query( f"index in {missing_indices}")
    print( crawled_results)
    er = extract_results( crawled_results, classes = args.classes, provider = args.provider, model = args.model)
    er = er['extracted_results']
    print(er['category'])
    postprocessed_results = postprocess_result(
        er,
        "/tmp/postprocessed_results.joblib",
        category2supercategory
    )
    out_formatted_results = format_output(
        postprocessed_results,
        input_column = 'evidence',
        output_column = 'formatted_evidence',
        format_func = format_evidence
    )
    out_formatted_results.to_csv( "/tmp/formatted_results.missing.csv", index=False)
    formatted_results = pd.concat([formatted_results, out_formatted_results], ignore_index=True)
    formatted_results.sort_values(by='index', ascending=True, inplace=True)
    formatted_results.to_csv( "/tmp/formatted_results.csv", index=False)
def main(args):
    """Run the full pipeline: crawl, extract, classify, combine, postprocess, and format.
    Argument
        args: argparse
    Note
        200 records
        crawl: 585.3285548686981
        extract: 2791.631685256958 (delay = 10)
        classify: 2374.4915606975555 (delay = 10)
    """
    crawled_file_path = os.path.join( args.output_dir, args.crawled_file_path)
    extracted_file_path = os.path.join( args.output_dir, args.extracted_file_path)
    classified_file_path = os.path.join( args.output_dir, args.classified_file_path)
    combined_file_path = os.path.join( args.output_dir, args.combined_file_path)
    postprocessed_results_path = os.path.join( args.output_dir, args.postprocessed_results)
    formatted_results_path = os.path.join( args.output_dir, args.formatted_results_path)

    ## Load the lead list ##
    data = get_leads(args.data_path)

    ## Crawl the search results ##
    crawled_results = crawl_results_mp( data, crawled_file_path, n_processes=args.n_processes)
    # crawled_results = { k:v[-5:] for k,v in crawled_results.items()}

    ## Method 1: extract key fields and a category from the evidence ##
    extracted_results = extract_results_mp(
        crawled_results = crawled_results['crawled_results'],
        extracted_file_path = extracted_file_path,
        classes = args.classes,
        provider = args.provider,
        model = args.model,
        n_processes = args.n_processes
    )

    ## Method 2: classify the crawled evidence directly ##
    classified_results = classify_results_mp(
        extracted_results['extracted_results'],
        classified_file_path,
        classes = args.classes,
        backup_classes = args.backup_classes,
        provider = args.provider,
        model = args.model,
        n_processes = args.n_processes
    )

    ## Combine the two analyses ##
    combined_results = combine_results(
        classified_results['classified_results'],
        combined_file_path,
        src_column = 'classified_category',
        tgt_column = 'category',
        strategy = args.strategy
    )

    ## Postprocess and format the combined results ##
    postprocessed_results = postprocess_result(
        combined_results,
        postprocessed_results_path,
        category2supercategory
    )
    formatted_results = format_output( postprocessed_results, input_column = 'evidence', output_column = 'formatted_evidence', format_func = format_evidence)
    formatted_results.to_csv( formatted_results_path, index=False)
category2supercategory = {
    "小吃店": "中式",
    "日式料理(含居酒屋,串燒)": "中式",
    "火(鍋/爐)": "中式",
    "東南亞料理(不含日韓)": "中式",
    "海鮮熱炒": "中式",
    "特色餐廳(含雞、鵝、牛、羊肉)": "中式",
    "傳統餐廳": "中式",
    "燒烤": "中式",
    "韓式料理(含火鍋,烤肉)": "中式",
    "西餐廳(含美式,義式,墨式)": "西式",
    "中式": "中式",
    "西式": "西式",
    "西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)": "西式",
    "西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)": "西式",
    "早餐": ""
}

supercategory2category = {
    "中式": [
        "小吃店",
        "日式料理(含居酒屋,串燒)",
        "火(鍋/爐)",
        "東南亞料理(不含日韓)",
        "海鮮熱炒",
        "特色餐廳(含雞、鵝、牛、羊肉)",
        "傳統餐廳",
        "燒烤",
        "韓式料理(含火鍋,烤肉)"
    ],
    "西式": ["西餐廳(含美式,義式,墨式)", "西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)", "西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)"],
    "": ["早餐"]
}
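# Quick sanity check of the two mappings (illustrative): every category listed under a
# supercategory should map back to that supercategory.
# assert all(category2supercategory[c] == sup
#            for sup, cats in supercategory2category.items() for c in cats)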
if __name__=='__main__':
    base = "https://serpapi.com/search.json"
    engine = 'google'
    # query = "Coffee"
    google_domain = 'google.com.tw'
    gl = 'tw'
    lr = 'lang_zh-TW'
    # url = f"{base}?engine={engine}&q={query}&google_domain={google_domain}&gl={gl}&lr={lr}"
    n_processes = 4
    client = OpenAI( organization = ORGANIZATION_ID)

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default="data/餐廳類型分類.xlsx - 測試清單.csv")
    parser.add_argument("--task", type=str, default="new", choices = ["new", "continue"], help="new or continue")
    parser.add_argument("--output_dir", type=str, help='output directory')
    parser.add_argument("--classified_file_path", type=str, default="classified_results.joblib")
    parser.add_argument("--extracted_file_path", type=str, default="extracted_results.joblib")
    parser.add_argument("--crawled_file_path", type=str, default="crawled_results.joblib")
    parser.add_argument("--combined_file_path", type=str, default="combined_results.joblib")
    parser.add_argument("--postprocessed_results", type=str, default="postprocessed_results.joblib")
    parser.add_argument("--formatted_results_path", type=str, default="formatted_results.csv")
    parser.add_argument("--classes", type=list, default=['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒', '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)', '西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)', '西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)', '早餐'])
    parser.add_argument("--backup_classes", type=list, default=['中式', '西式'])
    parser.add_argument("--strategy", type=str, default='patch', choices=['replace', 'patch'])
    parser.add_argument("--provider", type=str, default='openai', choices=['openai', 'anthropic'])
    parser.add_argument("--model", type=str, default='gpt-4-0125-preview', choices=['claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview'])
    parser.add_argument("--n_processes", type=int, default=4)
    args = parser.parse_args()

    if args.task == 'new':
        main(args)
    elif args.task == 'continue':
        continue_missing(args)
    else:
        raise Exception(f"Task {args.task} not implemented")