Spaces:
Running
Running
import gradio as gr | |
import json | |
import requests | |
import re | |
class Chatbot: | |
def __init__(self, config): | |
self.video_id = config.get('video_id') | |
self.content_subject = config.get('content_subject') | |
self.content_grade = config.get('content_grade') | |
self.jutor_chat_key = config.get('jutor_chat_key') | |
self.transcript_text = self.get_transcript_text(config.get('transcript')) | |
self.key_moments_text = self.get_key_moments_text(config.get('key_moments')) | |
self.ai_model_name = config.get('ai_model_name') | |
self.ai_client = config.get('ai_client') | |
self.instructions = config.get('instructions') | |
def get_transcript_text(self, transcript_data): | |
if isinstance(transcript_data, str): | |
transcript_json = json.loads(transcript_data) | |
else: | |
transcript_json = transcript_data | |
for entry in transcript_json: | |
entry.pop('end_time', None) | |
transcript_text = json.dumps(transcript_json, ensure_ascii=False) | |
return transcript_text | |
def get_key_moments_text(self, key_moments_data): | |
if isinstance(key_moments_data, str): | |
key_moments_json = json.loads(key_moments_data) | |
else: | |
key_moments_json = key_moments_data | |
# key_moments_json remove images | |
for moment in key_moments_json: | |
moment.pop('images', None) | |
moment.pop('end', None) | |
moment.pop('transcript', None) | |
key_moments_text = json.dumps(key_moments_json, ensure_ascii=False) | |
return key_moments_text | |
def chat(self, user_message, chat_history): | |
try: | |
messages = self.prepare_messages(chat_history, user_message) | |
system_prompt = self.instructions | |
system_prompt += "\n\n告知用戶你現在是誰,第一次加上科目學伴及名字,後面就只說名字就好,但不用每次都說,自然就好,不用每一句都特別說明,口氣請符合給予的人設,請用繁體中文回答" | |
service_type = self.ai_model_name | |
response_text = self.chat_with_service(service_type, system_prompt, messages) | |
except Exception as e: | |
print(f"Error: {e}") | |
response_text = "學習精靈有點累,請稍後再試!" | |
return response_text | |
def prepare_messages(self, chat_history, user_message): | |
messages = [] | |
if chat_history is not None: | |
if len(chat_history) > 10: | |
chat_history = chat_history[-10:] | |
for user_msg, assistant_msg in chat_history: | |
if user_msg: | |
messages.append({"role": "user", "content": user_msg}) | |
if assistant_msg: | |
messages.append({"role": "assistant", "content": assistant_msg}) | |
if user_message: | |
user_message += "/n (請一定要用繁體中文回答 zh-TW,並用台灣人的禮貌口語表達,回答時不要特別說明這是台灣人的語氣,不要提到「台灣腔」,不用提到「逐字稿」這個詞,用「內容」代替),回答時如果有用到數學式,請用數學符號代替純文字(Latex 用 $ 字號 render)" | |
messages.append({"role": "user", "content": user_message}) | |
return messages | |
def chat_with_service(self, service_type, system_prompt, messages): | |
if service_type == 'openai': | |
return self.chat_with_jutor(system_prompt, messages) | |
elif service_type == 'groq_llama3': | |
return self.chat_with_groq(service_type, system_prompt, messages) | |
elif service_type == 'groq_mixtral': | |
return self.chat_with_groq(service_type, system_prompt, messages) | |
elif service_type == 'claude3': | |
return self.chat_with_claude3(system_prompt, messages) | |
elif service_type in ['perplexity_sonar', 'perplexity_sonar_pro', 'perplexity_r1_1776']: | |
return self.chat_with_perplexity(service_type, system_prompt, messages) | |
else: | |
raise gr.Error("不支持的服务类型") | |
def chat_with_jutor(self, system_prompt, messages): | |
messages.insert(0, {"role": "system", "content": system_prompt}) | |
api_endpoint = "https://ci-live-feat-video-ai-dot-junyiacademy.appspot.com/api/v2/jutor/hf-chat" | |
headers = { | |
"Content-Type": "application/json", | |
"x-api-key": self.jutor_chat_key, | |
} | |
model = "gpt-4o" | |
print("======model======") | |
print(model) | |
data = { | |
"data": { | |
"messages": messages, | |
"max_tokens": 512, | |
"temperature": 0.9, | |
"model": model, | |
"stream": False, | |
} | |
} | |
response = requests.post(api_endpoint, headers=headers, data=json.dumps(data)) | |
response_data = response.json() | |
response_completion = response_data['data']['choices'][0]['message']['content'].strip() | |
return response_completion | |
def chat_with_groq(self, model_name, system_prompt, messages): | |
# system_prompt insert to messages 的最前面 {"role": "system", "content": system_prompt} | |
messages.insert(0, {"role": "system", "content": system_prompt}) | |
model_name_dict = { | |
"groq_llama3": "llama-3.1-70b-versatile", | |
"groq_mixtral": "mixtral-8x7b-32768" | |
} | |
model = model_name_dict.get(model_name) | |
print("======model======") | |
print(model) | |
request_payload = { | |
"model": model, | |
"messages": messages, | |
"max_tokens": 500 # 設定一個較大的值,可根據需要調整 | |
} | |
groq_client = self.ai_client | |
response = groq_client.chat.completions.create(**request_payload) | |
response_completion = response.choices[0].message.content.strip() | |
return response_completion | |
def chat_with_claude3(self, system_prompt, messages): | |
if not system_prompt.strip(): | |
raise ValueError("System prompt cannot be empty") | |
model_id = "anthropic.claude-3-sonnet-20240229-v1:0" | |
# model_id = "anthropic.claude-3-haiku-20240307-v1:0" | |
print("======model_id======") | |
print(model_id) | |
kwargs = { | |
"modelId": model_id, | |
"contentType": "application/json", | |
"accept": "application/json", | |
"body": json.dumps({ | |
"anthropic_version": "bedrock-2023-05-31", | |
"max_tokens": 500, | |
"system": system_prompt, | |
"messages": messages | |
}) | |
} | |
# 建立 message API,讀取回應 | |
bedrock_client = self.ai_client | |
response = bedrock_client.invoke_model(**kwargs) | |
response_body = json.loads(response.get('body').read()) | |
response_completion = response_body.get('content')[0].get('text').strip() | |
return response_completion | |
def chat_with_perplexity(self, service_type, system_prompt, messages): | |
"""使用 Perplexity API 進行對話""" | |
if not system_prompt.strip(): | |
raise ValueError("System prompt cannot be empty") | |
# 清理用戶訊息中的特殊指令 | |
for msg in messages: | |
if msg["role"] == "user": | |
# 移除可能導致問題的特殊指令 | |
msg["content"] = msg["content"].replace("/n", "\n") | |
# 移除括號內的特殊指令 | |
msg["content"] = re.sub(r'\(請一定要用繁體中文回答.*?\)', '', msg["content"]) | |
# 系統提示放在最前面 | |
clean_messages = [{"role": "system", "content": system_prompt}] | |
# 添加其他訊息 | |
for msg in messages: | |
if msg["role"] != "system": # 避免重複添加系統提示 | |
clean_messages.append(msg) | |
# 在系統提示中添加 Markdown 和 LaTeX 格式指導 | |
system_prompt += "\n\n重要:使用 LaTeX 數學符號時,請確保格式正確。數學表達式應該使用 $ 符號包圍,例如:$7 \\times 10^4$。不要使用 ** 符號來強調數字,而是使用 $ 符號,例如:$7$個萬 ($7 \\times 10000$)。不要使用 \\text 或 \\quad 等命令。" | |
# 根據服務類型選擇模型 | |
model_name_dict = { | |
"perplexity_sonar": "sonar", | |
"perplexity_sonar_pro": "sonar-pro", | |
"perplexity_r1_1776": "r1-1776" | |
} | |
model = model_name_dict.get(service_type, "sonar") | |
print("======model======") | |
print(model) | |
print("======clean_messages======") | |
print(json.dumps(clean_messages[:1], ensure_ascii=False)) # 只打印系統提示的前部分 | |
try: | |
perplexity_client = self.ai_client | |
# 針對 r1-1776 模型調整參數 | |
if service_type == "perplexity_r1_1776": | |
# 增加 max_tokens 並添加特殊指令 | |
response = perplexity_client.chat.completions.create( | |
model=model, | |
messages=clean_messages, | |
max_tokens=1000, # 增加 token 限制 | |
temperature=0.7, | |
top_p=0.9 | |
) | |
else: | |
response = perplexity_client.chat.completions.create( | |
model=model, | |
messages=clean_messages, | |
max_tokens=500, | |
temperature=0.7, | |
top_p=0.9 | |
) | |
# 檢查回應是否為空 | |
if not hasattr(response, 'choices') or len(response.choices) == 0: | |
print("警告:API 回傳無效回應結構") | |
return "學習精靈暫時無法回答,請稍後再試!" | |
response_completion = response.choices[0].message.content | |
if not response_completion or response_completion.strip() == "": | |
print("警告:API 回傳空回應") | |
return "學習精靈暫時無法回答,請稍後再試!" | |
# 處理回應中的思考過程標籤和修正 LaTeX 格式 | |
response_completion = self._process_response(response_completion) | |
# 打印處理後的回應以便調試 | |
print("======processed_response======") | |
print(response_completion) | |
return response_completion.strip() | |
except Exception as e: | |
print(f"Perplexity API Error: {e}") | |
print(f"Error details: {str(e)}") | |
# 嘗試使用備用模型 | |
try: | |
if service_type == "perplexity_r1_1776": | |
print("嘗試使用備用模型 sonar") | |
backup_response = perplexity_client.chat.completions.create( | |
model="sonar", | |
messages=clean_messages, | |
max_tokens=500, | |
temperature=0.7 | |
) | |
backup_completion = backup_response.choices[0].message.content | |
backup_completion = self._process_response(backup_completion) | |
return backup_completion.strip() | |
except Exception as backup_error: | |
print(f"備用模型也失敗: {backup_error}") | |
return "學習精靈暫時無法回答,請稍後再試!" | |
def _process_response(self, response_text): | |
"""處理回應中的思考過程標籤和修正 LaTeX 格式""" | |
# 移除 <think>...</think> 區塊 | |
import re | |
response_text = re.sub(r'<think>.*?</think>', '', response_text, flags=re.DOTALL) | |
# 移除其他可能的標籤或指令 | |
response_text = re.sub(r'(偷偷說.*?)', '', response_text, flags=re.DOTALL) | |
# 修正 Markdown 格式 | |
# 1. 確保項目符號前後有正確的空格和換行 | |
response_text = re.sub(r'(\n|^)(\s*)([-•○●◦])\s*', r'\1\2\3 ', response_text) | |
# 2. 確保數字列表前後有正確的空格和換行 | |
response_text = re.sub(r'(\n|^)(\s*)(\d+\.)\s*', r'\1\2\3 ', response_text) | |
# 3. 修正 LaTeX 格式 | |
# 移除不正確的 LaTeX 命令 | |
response_text = re.sub(r'\\text\{([^}]+)\}', r'\1', response_text) | |
response_text = re.sub(r'\\quad', ' ', response_text) | |
# 4. 修正數學表達式 | |
# 確保數學表達式中的乘法符號格式正確 | |
response_text = re.sub(r'(\d+)個「([^」]+)」→\s*(\d+)\\times(\d+)', r'\1個「\2」→ $\3\\times\4$', response_text) | |
# 5. 修正單獨數字的 LaTeX 格式 | |
# 將單獨的數字包裹在 $ 符號中 | |
response_text = re.sub(r'([^$\d])(\d+)([^$\d\w])', r'\1$\2$\3', response_text) | |
# 6. 修正連續的 LaTeX 表達式 | |
# 確保連續的 LaTeX 表達式之間有空格 | |
response_text = re.sub(r'\$([^$]+)\$\$([^$]+)\$', r'$\1$ $\2$', response_text) | |
# 7. 移除單獨的 $ 符號 | |
response_text = re.sub(r'(?<!\$)\$(?!\$)\s*$', '', response_text) | |
response_text = re.sub(r'^\s*\$(?!\$)', '', response_text) | |
response_text = re.sub(r'(?<!\$)\$(?!\$)\s*\n', '\n', response_text) | |
# 8. 確保成對的 $ 符號 | |
dollar_count = response_text.count('$') | |
if dollar_count % 2 != 0: | |
# 如果 $ 符號數量為奇數,移除最後一個 $ | |
last_dollar_pos = response_text.rfind('$') | |
if last_dollar_pos != -1: | |
response_text = response_text[:last_dollar_pos] + response_text[last_dollar_pos+1:] | |
# 9. 修正錯誤的粗體標記 | |
# 將 **數字** 格式修正為正確的數字格式 | |
response_text = re.sub(r'\*\*(\d+)\*\*', r'$\1$', response_text) | |
# 如果處理後的回應為空,返回原始回應 | |
if not response_text.strip(): | |
return "學習精靈暫時無法回答,請稍後再試!" | |
return response_text | |