File size: 9,017 Bytes
60274d1
07d2942
 
948e91c
07d2942
 
 
60274d1
07d2942
60274d1
07d2942
 
 
60274d1
07d2942
 
 
60274d1
07d2942
 
 
 
 
 
 
60274d1
 
 
8aa4241
60274d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
948e91c
 
8aa4241
60274d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07d2942
60274d1
 
 
 
 
07d2942
60274d1
 
 
 
 
 
07d2942
60274d1
 
07d2942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60274d1
 
 
 
 
 
 
07d2942
 
 
 
 
60274d1
07d2942
 
60274d1
07d2942
 
60274d1
 
07d2942
 
 
 
 
60274d1
 
 
 
 
07d2942
2547429
 
07d2942
 
 
2547429
07d2942
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import sys
import json
import time
import yaml
import joblib
import argparse

import jinja2
import anthropic
import pandas as pd
from tqdm import tqdm
from loguru import logger
from openai import OpenAI
from dotenv import load_dotenv
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

from utils import parse_json_garbage, compose_query

try:
    logger.remove(0)
    logger.add(sys.stderr, level="INFO")
except ValueError:
    pass

load_dotenv()

def llm( provider, model, system_prompt, user_content, delay:int = 0):
    """Invoke LLM service
    Argument
    --------
    provider: str
        openai or anthropic
    model: str
        Model name for the API
    system_prompt: str
        System prompt for the API
    user_content: str
        User prompt for the API
    Return
    ------
    response: str
    """
    if delay:
        time.sleep(delay)

    if provider=='openai':
        client = OpenAI( organization = os.getenv('ORGANIZATION_ID'))
        chat_completion = client.chat.completions.create( 
            messages=[
                {
                    "role": "system",
                    "content": system_prompt
                },
                {
                    "role": "user",
                    "content": user_content,
                }
            ],
            model = model, 
            response_format = {"type": "json_object"},
            temperature = 0,
            max_tokens = 4096,
            # stream = True
        )
        response = chat_completion.choices[0].message.content

    elif provider=='anthropic':
        client = anthropic.Client(api_key=os.getenv('ANTHROPIC_API_KEY'))
        response = client.messages.create(
            model= model,
            system= system_prompt,
            messages=[
                {"role": "user", "content": user_content} # <-- user prompt
            ],
            max_tokens = 4000
        )
        response = response.content[0].text

    elif provider=='google':
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        model = genai.GenerativeModel(
            model_name = model,  
            system_instruction = system_prompt,
            generation_config={
            "temperature": 0,
            "max_output_tokens": 8192,
            "response_mime_type": "application/json"
        })
        safety_settings = {
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        }
        messages = []
        # messages.append({
        #     'role':'user',
        #     'parts': [f"System instruction: {system_prompt}"]
        # })
        # response = model.generate_content(messages, safety_settings=safety_settings)
        # try:
        #     messages.append({
        #         'role': 'model',
        #         'parts': [response.text]
        #     })
        # except Exception as e:
        #     logger.error(f"response.candidates -> {response.candidates}")
        #     logger.error(f"error -> {e}")
        #     messages.append({
        #         'role': 'model',
        #         'parts': ["OK. I'm ready to help you."]
        #     })
        messages.append({
            'role': 'user',
            'parts': [user_content]
        })
        try:
            response = model.generate_content(messages, safety_settings=safety_settings, )
            response = response.text
        except Exception as e:
            logger.error(f"Error (will still return response) -> {e}")
            logger.error(f"response.candidates -> {response.candidates}")
            return response
    else:
        raise Exception("Invalid provider")
    
    return response

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument( "-c", "--config", type=str, default='config/config.yml', help="Path to the configuration file")
    parser.add_argument( "-t", "--task", type=str, default='prepare_batch', choices=['extract', 'classify'])
    parser.add_argument( "-i", "--input_path", type=str, default='', )
    parser.add_argument( "-o", "--output_path", type=str, default='', )
    parser.add_argument( "-topn", "--topn", type=int, default=None )
    args = parser.parse_args()
    # classes = ['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒',  '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)', ]
    # backup_classes = [ '中式', '西式']

    assert os.path.exists(args.config), f"File not found: {args.config}"
    config = yaml.safe_load(open(args.config, "r").read())

    if args.task == 'extract':
        jenv = jinja2.Environment()
        template = jenv.from_string(config['extraction_prompt'])
        system_prompt = template.render( classes = config['classes'], traits = config['traits'])
        query = "山の迴饗"
        search_results = str([{"title": "山の迴饗", "snippet": "謝謝大家這麼支持山の迴饗 我們會繼續努力用心做出美味的料理 ————————— ⛰️ 山の迴饗地址:台東縣關山鎮中華路56號訂位專線:0975-957-056 · #山的迴饗 · #夢想起飛"}, {"title": "山的迴饗餐館- 店家介紹", "snippet": "營業登記資料 · 統一編號. 92433454 · 公司狀況. 營業中 · 公司名稱. 山的迴饗餐館 · 公司類型. 獨資 · 資本總額. 30000 · 所在地. 臺東縣關山鎮中福里中華路56號 · 使用發票."}, {"title": "關山漫遊| 💥山の迴饗x night bar", "snippet": "山の迴饗x night bar 即將在12/1號台東關山開幕! 別再煩惱池上、鹿野找不到宵夜餐酒館 各位敬請期待並關注我們✨ night bar❌山的迴饗 12/1 ..."}, {"title": "山的迴饗| 中西複合式餐廳|焗烤飯|義大利麵 - 台灣美食網", "snippet": "山的迴饗| 中西複合式餐廳|焗烤飯|義大利麵|台式三杯雞|滷肉飯|便當|CP美食營業時間 ; 星期一, 休息 ; 星期二, 10:00–14:00 16:00–21:00 ; 星期三, 10:00–14:00 16:00– ..."}, {"title": "便當|CP美食- 山的迴饗| 中西複合式餐廳|焗烤飯|義大利麵", "snippet": "餐廳山的迴饗| 中西複合式餐廳|焗烤飯|義大利麵|台式三杯雞|滷肉飯|便當|CP美食google map 導航. 臺東縣關山鎮中華路56號 +886 975 957 056 ..."}, {"title": "山的迴饗餐館", "snippet": "山的迴饗餐館,統編:92433454,地址:臺東縣關山鎮中福里中華路56號,負責人姓名:周偉慈,設立日期:112年11月15日."}, {"title": "山的迴饗餐館", "snippet": "山的迴饗餐館. 資本總額(元), 30,000. 負責人, 周偉慈. 登記地址, 看地圖 臺東縣關山鎮中福里中華路56號 郵遞區號查詢. 設立日期, 2023-11-15. 資料管理 ..."}, {"title": "山的迴饗餐館, 公司統一編號92433454 - 食品業者登錄資料集", "snippet": "公司或商業登記名稱山的迴饗餐館的公司統一編號是92433454, 登錄項目是餐飲場所, 業者地址是台東縣關山鎮中福里中華路56號, 食品業者登錄字號是V-202257990-00001-5."}, {"title": "山的迴饗餐館, 公司統一編號92433454 - 食品業者登錄資料集", "snippet": "公司或商業登記名稱山的迴饗餐館的公司統一編號是92433454, 登錄項目是公司/商業登記, 業者地址是台東縣關山鎮中福里中華路56號, 食品業者登錄字號是V-202257990-00000-4 ..."}, {"title": "山的迴饗餐館", "snippet": "負責人, 周偉慈 ; 登記地址, 台東縣關山鎮中福里中華路56號 ; 公司狀態, 核准設立 「查詢最新營業狀況請至財政部稅務入口網 」 ; 資本額, 30,000元 ; 所在縣市 ..."}, {"title": "山的迴饗 | 關山美食|焗烤飯|酒吧|義大利麵|台式三杯雞|滷肉飯|便當|CP美食", "顧客評價": "324晚餐餐點豬排簡餐加白醬焗烤等等餐點。\t店家也提供免費的紅茶 綠茶 白開水 多種的調味料自取 總而言之 CP值真的很讚\t空間舒適涼爽,店員服務周到"}, {"title": "類似的店", "snippet": "['中國菜']\t['客家料理']\t['餐廳']\t['熟食店']\t['餐廳']"}, {"telephone_number": "0975 957 056"}])
        user_content = f'''
            `query`: `{query}`,
            `search_results`: {search_results}
        '''
        print(f"user_content -> {user_content}")
        resp = llm( config['provider'], config['model'], system_prompt, user_content)
        print(resp)

    elif args.task == 'classify':
        system_prompt = config['classification_prompt']

    else:
        raise Exception("Invalid task")