|
from pathlib import Path |
|
import os |
|
import requests |
|
from typing import List, Dict, Any, Optional, Iterator |
|
from PIL import Image |
|
import re |
|
|
|
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings |
|
|
|
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolMetadata |
|
from llama_index.core.agent import ReActAgent |
|
|
|
from openai import OpenAI |
|
|
|
import sys |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
|
|
import streamlit as st |
|
|
|
from llama_index.core import PromptTemplate |
|
react_system_header_str = """\ |
|
|
|
你被设计为能够帮助完成各种任务,从回答问题到提供摘要,再到其他类型的分析。 |
|
|
|
## 工具 |
|
你可以访问多种工具。你需要根据任务的需要,按适当的顺序使用这些工具来完成任务。 |
|
这可能需要将任务分解为子任务,并为每个子任务使用不同的工具。 |
|
|
|
你可以使用以下工具: |
|
{tool_desc} |
|
|
|
## 输出格式 |
|
为了回答问题,请使用以下格式,请务必以 Thought 开始: |
|
|
|
``` |
|
Thought: 我需要使用一个工具来帮助回答这个问题。 |
|
Action: 工具名称(从 {tool_names} 中选择)。 |
|
Action Input: 提供给工具的输入,使用 JSON 格式表示 kwargs(例如:{{"input": "你好世界", "num_beams": 5}})。 |
|
``` |
|
|
|
请务必以 Thought 开始。 |
|
|
|
请确保 Action Input 使用有效的 JSON 格式。不要使用 {{'input': '你好世界', 'num_beams': 5}}。 |
|
|
|
如果使用了这种格式,用户会以以下格式进行回复: |
|
|
|
``` |
|
Observation: 工具的返回结果 |
|
``` |
|
|
|
你应该重复以上格式,直到你有足够的信息来回答问题而无需再使用工具。此时,你必须以以下两种格式之一进行回复: |
|
|
|
``` |
|
Thought: (Implicit) I can answer without any more tools! Answer: [你的答案在这里] |
|
Thought: (Implicit) I can answer without any more tools! Answer: 抱歉,我无法回答你的问题。 |
|
``` |
|
|
|
## 附加规则 |
|
- 答案必须包含一系列项目符号(即“•”),用来解释你是如何得出答案的。这可以包括先前对话历史中的内容。 |
|
- 你必须遵守每个工具的函数签名。如果函数需要参数,请不要传入空参数。 |
|
|
|
## 当前对话 |
|
以下是由用户和助手消息交替组成的当前对话: |
|
|
|
|
|
""" |
|
|
|
|
|
class PromptEngineerAgent: |
|
"""专门用于优化提示词的代理""" |
|
def __init__(self, llm): |
|
self.llm = llm |
|
|
|
def optimize_image_prompt(self, user_input: str) -> str: |
|
""" |
|
将用户的图像需求转换为优化的stable-diffusion提示词 |
|
""" |
|
prompt_template = f""" |
|
请将以下用户的图像需求转换为stable-diffusion所需的文生图提示词。 |
|
|
|
用户需求: {user_input} |
|
|
|
请生成一个优化的英文提示词,格式要求: |
|
1. 使用详细的描述性语言 |
|
2. 包含具体的艺术风格 |
|
3. 说明构图和视角 |
|
4. 描述光影和氛围 |
|
5. 添加相关的艺术家参考或风格类型 |
|
|
|
提示词: |
|
""" |
|
|
|
response = self.llm.complete(prompt_template) |
|
return str(response) |
|
|
|
def optimize_voice_prompt(self, user_input: str) -> Dict[str, str]: |
|
""" |
|
优化语音合成的参数 |
|
""" |
|
prompt_template = f""" |
|
请分析以下文本,并提供优化的语音合成参数。 |
|
|
|
文本: {user_input} |
|
|
|
请考虑: |
|
1. 最适合的语言 |
|
2. 说话的语速 |
|
3. 语气特点 |
|
4. 情感色彩 |
|
|
|
以JSON格式返回参数: |
|
""" |
|
|
|
response = self.llm.complete(prompt_template) |
|
try: |
|
params = eval(str(response)) |
|
return params |
|
except: |
|
return {"lang": "zh", "speed": 1.0} |
|
|
|
|
|
class MultiModalAssistant: |
|
def __init__(self, data_source_dir, llm, api_key): |
|
""" |
|
初始化助手,设置必要的API密钥和加载文档 |
|
""" |
|
|
|
|
|
self.llm = llm |
|
self.__api_key = api_key |
|
|
|
self.prompt_engineer = PromptEngineerAgent(self.llm) |
|
|
|
|
|
documents = SimpleDirectoryReader(data_source_dir, recursive=False, required_exts=[".txt"]).load_data() |
|
self.index = VectorStoreIndex.from_documents( |
|
documents |
|
) |
|
|
|
|
|
self.query_engine = self.index.as_query_engine(similarity_top_k=3) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools = [ |
|
FunctionTool.from_defaults( |
|
fn=self.rag_query, |
|
name="rag_tool", |
|
description="无法直接回答时,查询和《黑神话:悟空》有关知识的工具" |
|
), |
|
FunctionTool.from_defaults( |
|
fn=self.text_to_speech, |
|
name="tts_tool", |
|
description="将文本转换为语音的工具" |
|
), |
|
FunctionTool.from_defaults( |
|
fn=self.generate_image, |
|
name="image_tool", |
|
description="生成图像的工具" |
|
) |
|
] |
|
|
|
|
|
self.agent = ReActAgent.from_tools( |
|
tools, |
|
llm=self.llm, |
|
verbose=True, |
|
max_iterations=10, |
|
max_function_calls=12, |
|
) |
|
|
|
|
|
react_system_prompt = PromptTemplate(react_system_header_str) |
|
self.agent.update_prompts({"agent_worker:system_prompt": react_system_prompt}) |
|
|
|
|
|
self.image_url = None |
|
self.audio_save_file = "audio.mp3" |
|
self.audio_text = None |
|
|
|
def rag_query(self, query: str) -> str: |
|
""" |
|
使用RAG系统查询知识库 |
|
""" |
|
response = self.query_engine.query(query) |
|
return str(response) |
|
|
|
def text_to_speech(self, text: str) -> str: |
|
""" |
|
将文本转换为语音 |
|
""" |
|
if not self.audio_text is None: |
|
print(f"文本已转为语音: {self.audio_text}") |
|
return |
|
try: |
|
client = OpenAI( api_key = self.__api_key, base_url="https://api.siliconflow.cn/v1") |
|
|
|
with client.audio.speech.with_streaming_response.create( |
|
model="fishaudio/fish-speech-1.5", |
|
voice="fishaudio/fish-speech-1.5:benjamin", |
|
|
|
input=f"{text}", |
|
response_format="mp3" |
|
) as response: |
|
response.stream_to_file(self.audio_save_file) |
|
|
|
if response.status_code == 200: |
|
self.audio_text = text |
|
print(f"文本已转为语音: {self.audio_save_file}") |
|
return f"文本转语音已完成。" |
|
else: |
|
print("文本转语音失败,状态码:", response.status_code) |
|
except Exception as e: |
|
return f"文本转语音时出错: {str(e)}" |
|
|
|
def generate_image(self, prompt: str) -> str: |
|
""" |
|
使用API生成图像 |
|
""" |
|
if not self.image_url is None: |
|
print(f"图像已生成: {self.image_url}") |
|
return |
|
try: |
|
|
|
optimized_prompt = self.prompt_engineer.optimize_image_prompt(prompt) |
|
print(f"优化后的图像提示词: {optimized_prompt}") |
|
|
|
|
|
url = "https://api.siliconflow.cn/v1/images/generations" |
|
payload = { |
|
"model": "stabilityai/stable-diffusion-3-5-large", |
|
"prompt": f"{optimized_prompt}", |
|
"negative_prompt": "<string>", |
|
"image_size": "1024x1024", |
|
"batch_size": 1, |
|
"seed": 4999999999, |
|
"num_inference_steps": 20, |
|
"guidance_scale": 7.5, |
|
"prompt_enhancement": False |
|
} |
|
|
|
headers = { |
|
"Authorization": f"Bearer {self.__api_key}", |
|
"Content-Type": "application/json" |
|
} |
|
response = requests.request("POST", url, json=payload, headers=headers) |
|
if response.status_code == 200: |
|
data = response.json() |
|
self.image_url = data['data'][0]['url'] |
|
print(f"图像已生成: {self.image_url}") |
|
return f"图像已生成。" |
|
|
|
else: |
|
print("生成图像失败,状态码:", response.status_code) |
|
|
|
except Exception as e: |
|
return f"生成图像时出错: {str(e)}" |
|
|
|
def chat(self, user_input: str) -> dict: |
|
""" |
|
处理用户输入并返回适当的响应 |
|
""" |
|
|
|
prompt = f""" |
|
用户输入: {user_input} |
|
|
|
请根据以下规则处理这个请求: |
|
1. 如果是知识相关的问题,使用rag_tool查询知识库 |
|
2. 如果用户要求语音输出,使用tts_tool转换文本 |
|
3. 如果用户要求生成图像,使用image_tool生成 |
|
|
|
根据需求请选择合适的工具并执行操作,可能需要多个工具。 |
|
""" |
|
self.image_url = None |
|
self.audio_text = None |
|
response = self.agent.chat(prompt) |
|
response_dict = {"response": str(response), "image_url": self.image_url, "audio_text": self.audio_text } |
|
return response_dict |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
wulewule_assistant = load_wulewule_agent() |
|
|
|
|
|
if "messages" not in st.session_state: |
|
st.session_state["messages"] = [] |
|
|
|
|
|
with st.sidebar: |
|
st.markdown("## 悟了悟了💡") |
|
logo_path = "assets/sd_wulewule.webp" |
|
if os.path.exists(logo_path): |
|
image = Image.open(logo_path) |
|
st.image(image, caption='wulewule') |
|
"[InternLM](https://github.com/InternLM)" |
|
"[悟了悟了](https://github.com/xzyun2011/wulewule.git)" |
|
|
|
|
|
st.title("悟了悟了:黑神话悟空AI助手🐒") |
|
|
|
|
|
for msg in st.session_state.messages: |
|
st.chat_message("user").write(msg["user"]) |
|
assistant_res = msg["assistant"] |
|
if isinstance(assistant_res, str): |
|
st.chat_message("assistant").write(assistant_res) |
|
elif isinstance(assistant_res, dict): |
|
image_url = assistant_res["image_url"] |
|
audio_text = assistant_res["audio_text"] |
|
st.chat_message("assistant").write(assistant_res["response"]) |
|
if image_url: |
|
|
|
st.image( image_url, width=256 ) |
|
if audio_text: |
|
|
|
st.audio("audio.mp3") |
|
st.write(f"语音内容为: {audio_text}") |
|
|
|
|
|
|
|
if prompt := st.chat_input("请输入你的问题,换行使用Shfit+Enter。"): |
|
|
|
st.chat_message("user").write(prompt) |
|
|
|
full_answer = "" |
|
with st.chat_message('robot'): |
|
message_placeholder = st.empty() |
|
response_dict = wulewule_assistant.chat(prompt) |
|
image_url = response_dict["image_url"] |
|
audio_text = response_dict["audio_text"] |
|
for cur_response in response_dict["response"]: |
|
full_answer += cur_response |
|
|
|
message_placeholder.markdown(full_answer + '▌') |
|
message_placeholder.markdown(full_answer) |
|
|
|
st.session_state.messages.append({"user": prompt, "assistant": response_dict}) |
|
if image_url: |
|
|
|
st.image( image_url, width=256 ) |
|
|
|
if audio_text: |
|
|
|
st.audio("audio.mp3") |
|
st.write(f"语音内容为: {audio_text}") |
|
|