Spaces:
Running
Running
""" | |
参考博客:https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA | |
""" | |
# 导入必要的库与模块 | |
import json | |
import os | |
import textwrap | |
import requests | |
from dotenv import load_dotenv | |
from langchain.chat_models import ChatOpenAI | |
from langchain.document_loaders import TextLoader | |
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, TensorflowHubEmbeddings | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.schema.output_parser import StrOutputParser | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.vectorstores import Weaviate | |
from weaviate import Client | |
from weaviate.embedded import EmbeddedOptions | |
from zhipuai import ZhipuAI | |
from openai import AzureOpenAI, OpenAI | |
# 环境设置与文档下载 | |
load_dotenv() # 加载环境变量 | |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # 从环境变量获取 OpenAI API 密钥 | |
MIMIMAX_API_KEY = os.getenv("MIMIMAX_API_KEY") | |
MIMIMAX_GROUP_ID = os.getenv("MIMIMAX_GROUP_ID") | |
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY") | |
KIMI_OPENAI_API_KEY = os.getenv("KIMI_OPENAI_API_KEY") | |
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_OPENAI_API_KEY") | |
AZURE_OPENAI_KEY = os.getenv("AZURE_OPENAI_KEY") | |
AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT") | |
# 确保 OPENAI_API_KEY 被正确设置 | |
if not OPENAI_API_KEY: | |
raise ValueError("OpenAI API Key not found in the environment variables.") | |
# 文档加载与分割 | |
def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50): | |
"""加载文档并分割成小块""" | |
loader = TextLoader(file_path) | |
documents = loader.load() | |
text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) | |
chunks = text_splitter.split_documents(documents) | |
return chunks | |
# 向量存储建立 | |
def create_vector_store(chunks, model="OpenAI"): | |
"""将文档块转换为向量并存储到 Weaviate 中""" | |
client = Client(embedded_options=EmbeddedOptions()) | |
if model == "OpenAI": | |
embedding_model = OpenAIEmbeddings() | |
elif model == "HuggingFace": | |
embedding_model = HuggingFaceEmbeddings() | |
elif model == "TensorflowHub": | |
embedding_model = TensorflowHubEmbeddings() | |
else: | |
raise ValueError(f"Unsupported embedding model: {model}") | |
vectorstore = Weaviate.from_documents( | |
client=client, | |
documents=chunks, | |
embedding=embedding_model, | |
by_text=False | |
) | |
return vectorstore | |
def get_retriever(vectorstore, k=4): | |
return vectorstore.as_retriever(search_kwargs={'k': k}) | |
def setup_rag_chain(model_name="kimi", temperature=0): | |
"""设置检索增强生成流程""" | |
if model_name.startswith("gpt"): | |
# 如果是以gpt开头的模型,使用原来的逻辑 | |
prompt_template = """ | |
您是一个擅长问答任务的专业助手。在执行问答任务时,应优先考虑所提供的**上下文信息**来形成回答,并适当参照**对话历史**。 | |
如果**上下文信息**与**问题**无直接相关性,您应依据自己的知识库向提问者提供准确的信息。务必确保您的答案在相关性、准确性和可读性方面达到高标准。 | |
**对话历史**: {conversation_history} | |
**问题**: {question} | |
**上下文信息**: {context} | |
**回答**: | |
""" | |
prompt = ChatPromptTemplate.from_template(prompt_template) | |
llm = ChatOpenAI(model_name=model_name, temperature=temperature) | |
# 创建 RAG 链,参考 https://python.langchain.com/docs/expression_language/ | |
rag_chain = ( | |
prompt | |
| llm | |
| StrOutputParser() | |
) | |
else: | |
# 如果不是以gpt开头的模型,返回None | |
rag_chain = None | |
return rag_chain | |
# 执行查询并打印结果 | |
def execute_query(retriever, rag_chain, query, model_name="kimi", temperature=0): | |
""" | |
执行查询并返回结果及检索到的文档块 | |
参数: | |
retriever: 文档检索器对象 | |
rag_chain: 检索增强生成链对象,如果为None则不使用RAG链 | |
query: 查询问题 | |
model_name: 使用的语言模型名称,默认为"gpt-4" | |
temperature: 生成温度,默认为0 | |
返回: | |
retrieved_documents: 检索到的文档块列表 | |
response_text: 生成的回答文本 | |
""" | |
if isinstance(query, list): | |
[conversation_history, question] = query | |
else: | |
conversation_history = '' | |
question = query | |
# 使用检索器检索相关文档块 | |
retrieved_documents = retriever.invoke(question) | |
if rag_chain is not None: | |
# 如果有RAG链,则使用RAG链生成回答 | |
rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": question}) | |
response_text = rag_chain_response | |
else: | |
prompt_template = """ | |
【对话历史】: {conversation_history} | |
【上下文信息】: {context} | |
您是一个擅长问答任务的专业助手。在执行问答任务时,应优先考虑所提供的【上下文信息】来形成回答,并适当参照【对话历史】。 | |
如果【上下文信息】与【问题】无直接相关性,您应依据自己的知识库向提问者提供准确的信息。务必确保您的答案在相关性、准确性和可读性方面达到高标准。 | |
【问题】: {question} | |
【回答】: | |
""" | |
context = '\n'.join( | |
[retrieved_documents[i].page_content for i in range(len(retrieved_documents))]) | |
prompt = prompt_template.format(conversation_history=conversation_history, question=question, context=context) | |
response_text = execute_query_no_rag(model_name=model_name, temperature=temperature, query=prompt) | |
return retrieved_documents, response_text | |
def execute_query_no_rag(model_name="kimi", temperature=0, query=""): | |
"""执行无 RAG 链的查询""" | |
if model_name.startswith("gpt"): | |
# 如果是以gpt开头的模型,使用原来的逻辑 | |
llm = ChatOpenAI(model_name=model_name, temperature=temperature) | |
response = llm.invoke(query) | |
return response.content | |
elif model_name.startswith("azure_gpt"): | |
client = AzureOpenAI( | |
azure_endpoint=AZURE_ENDPOINT, | |
api_key=AZURE_OPENAI_KEY, | |
api_version="2024-02-15-preview" | |
) | |
message_text = [{"role": "user", "content": query}, ] | |
completion = client.chat.completions.create( | |
model=model_name[6:], # model_name = 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo' | |
messages=message_text, | |
temperature=temperature, | |
top_p=0.95, | |
frequency_penalty=0, | |
presence_penalty=0, | |
stop=None | |
) | |
return completion.choices[0].message.content | |
elif model_name == 'abab6-chat': | |
# 如果是'abab6-chat'模型,使用专门的API调用方式 | |
url = "https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId=" + MIMIMAX_GROUP_ID | |
headers = {"Content-Type": "application/json", "Authorization": "Bearer " + MIMIMAX_API_KEY} | |
payload = { | |
"bot_setting": [ | |
{ | |
"bot_name": "MM智能助理", | |
"content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。", | |
} | |
], | |
"messages": [{"sender_type": "USER", "sender_name": "小明", "text": query}], | |
"reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"}, | |
"model": model_name, | |
"tokens_to_generate": 1034, | |
"temperature": temperature, | |
"top_p": 0.9, | |
} | |
response = requests.request("POST", url, headers=headers, json=payload) | |
# 将 JSON 字符串解析为字典 | |
response_dict = json.loads(response.text) | |
# 提取 'reply' 键对应的值 | |
return response_dict['reply'] | |
elif model_name == 'glm-4': | |
# 如果是'glm-4'模型,使用专门的API调用方式 | |
client = ZhipuAI(api_key=ZHIPUAI_API_KEY) # 填写您自己的APIKey | |
response = client.chat.completions.create( | |
model=model_name, # 填写需要调用的模型名称 | |
messages=[{"role": "user", "content": query}] | |
) | |
return response.choices[0].message.content | |
elif model_name == 'kimi': | |
# 如果是'kimi'模型,使用专门的API调用方式 | |
client = OpenAI( | |
api_key=KIMI_OPENAI_API_KEY, | |
base_url="https://api.moonshot.cn/v1", | |
) | |
messages = [ | |
{ | |
"role": "system", | |
"content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。", | |
}, | |
{"role": "user", | |
"content": query}, | |
] | |
completion = client.chat.completions.create( | |
# model="moonshot-v1-128k", | |
model="moonshot-v1-32k", | |
messages=messages, | |
temperature=temperature, | |
top_p=1.0, | |
n=1, # 为每条输入消息生成多少个结果 | |
stream=False # 流式输出 | |
) | |
return completion.choices[0].message.content | |
elif model_name == 'deepseek': | |
# 如果是'deepseek'模型,使用专门的API调用方式 | |
client = OpenAI( | |
api_key="sk-c4a8fe52693a4aaab64e648c42f40be6", | |
base_url="https://api.deepseek.com" | |
) | |
response = client.chat.completions.create( | |
model="deepseek-chat", # deepseek-coder | |
messages=[ | |
{"role": "system", "content": "You are a helpful assistant"}, | |
{"role": "user", "content": query}, | |
], | |
# max_tokens=4096, | |
# max_tokens=32000, | |
temperature=0.7, | |
stream=False, | |
frequency_penalty=0, | |
presence_penalty=0, | |
top_p=1, | |
logprobs=False, | |
) | |
return response.choices[0].message.content | |
else: | |
# 如果模型不支持,抛出异常 | |
raise ValueError(f"Unsupported model: {model_name}") | |
if __name__ == "__main__": | |
# 假设文档已存在于本地 | |
file_path = './documents/LightZero_README_zh.md' | |
# model_name = "glm-4" # model_name=['abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'] | |
# model_name = 'azure_gpt-4' | |
# model_name = 'kimi' | |
model_name = 'deepseek' | |
temperature = 0.01 | |
embedding_model = 'OpenAI' # embedding_model=['HuggingFace', 'TensorflowHub', 'OpenAI'] | |
# 加载和分割文档 | |
chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500) | |
# 创建向量存储 | |
vectorstore = create_vector_store(chunks, model=embedding_model) | |
retriever = get_retriever(vectorstore, k=5) | |
# 设置 RAG 流程 | |
rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature) | |
# 提出问题并获取答案 | |
query = ("请回答下面的问题:(1)请简要介绍一下 LightZero。(2)请详细介绍 LightZero 的框架结构。 (3)请给出安装 LightZero,运行他们的示例代码的详细步骤。(4)- 请问 LightZero 具体支持什么任务(tasks/environments)? (5)请问 LightZero 具体支持什么算法?(6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行? (7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗?(8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?(9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢?(10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。(11)请问对这个仓库提出详细的改进建议") | |
""" | |
(1)请简要介绍一下 LightZero。 | |
(2)请详细介绍 LightZero 的框架结构。 | |
(3)请给出安装 LightZero,运行他们的示例代码的详细步骤 。 | |
(4)请问 LightZero 具体支持什么任务(tasks/environments)? | |
(5)请问 LightZero 具体支持什么算法? | |
(6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行? | |
(7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗? | |
(8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗? | |
(9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢? | |
(10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。 | |
(11)请问对这个仓库提出详细的改进建议。 | |
""" | |
# 使用 RAG 链获取参考的文档与答案 | |
retrieved_documents, result_with_rag = execute_query(retriever, rag_chain, query, model_name=model_name, | |
temperature=temperature) | |
# 不使用 RAG 链获取答案 | |
result_without_rag = execute_query_no_rag(model_name=model_name, query=query, temperature=temperature) | |
# 打印并对比两种方法的结果 | |
# 使用textwrap.fill来自动分段文本,width参数可以根据你的屏幕宽度进行调整 | |
wrapped_result_with_rag = textwrap.fill(result_with_rag, width=80) | |
wrapped_result_without_rag = textwrap.fill(result_without_rag, width=80) | |
context = '\n'.join( | |
[f'**Document {i}**: ' + retrieved_documents[i].page_content for i in range(len(retrieved_documents))]) | |
# 打印自动分段后的文本 | |
print("=" * 40) | |
print(f"我的问题是:\n{query}") | |
print("=" * 40) | |
print(f"Result with RAG:\n{wrapped_result_with_rag}\n检索得到的context是: \n{context}") | |
print("=" * 40) | |
print(f"Result without RAG:\n{wrapped_result_without_rag}") | |
print("=" * 40) | |