Spaces:
Runtime error
Runtime error
# Long-text summarisation and retrieval-QA setup.
import os

import openai
from openai.error import RateLimitError
from langchain import OpenAI
from langchain import PromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain.chains import RetrievalQA
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document as LangDoc
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.output_parsers import RegexParser
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Shared splitter (chunk_size=1500, chunk_overlap=100) used both to summarise
# transcripts and to chunk uploaded documents before indexing them.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)

# Seed the key from OPENAI_API_KEY when it is set, so a key supplied through
# the environment is not overwritten with an empty string; the "设置" button
# (set_api_key) can still replace it at runtime.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
# Prompt for summarising one transcript chunk. get_chatgpt_summary feeds it to
# a PromptTemplate whose single input variable is {text}. It asks for the main
# topics only, numbered, without colloquial wording or unrelated content.
summary_prompt = "\n".join(
    [
        "总结以下会议记录中所探讨的主要话题,忽略细节",
        "会议记录:{text}",
        "在输出时,请注意以下几点:",
        "1. 输出内容中避免口语化内容",
        "2. 每个话题用序号标注",
        "3. 不输出无关信息",
    ]
)

# Retrieval-QA prompt (inputs: {context}, {question}). It instructs the model to
# answer in Chinese, say "I don't know" rather than invent an answer, and emit a
# "回答:" (answer) line followed by a "分数:" (0-100 confidence score) line,
# which get_chatgpt_reply's output parser extracts.
qa_prompt = """
结合下面的信息,用中文回答最后的问题。如果你不知道答案,说“我不知道”,不可以编造答案。
除了回答问题外,还需要输出一个分数,表示你对这个问题的回答的自信程度。分数越高,你越自信。按照以下的格式输出:
回答:[回答内容]
分数:[0到100间的数字]
开始回答:
{context}
问题:{question}
"""
# Extracts (answer, score) from a reply written in qa_prompt's format:
# "回答:<answer>" then "分数:<0-100>".
#   - (?s) lets a multi-line answer be captured whole.
#   - The optional "回答:" prefix is stripped from the answer.
#   - Both the ASCII ":" and the full-width "：" colon are accepted.
#   - The score must contain at least one digit, so the map_rerank chain never
#     has to rank an empty score string.
QA_OUTPUT_REGEX = r"(?s)\s*(?:回答[:：]\s*)?(.*?)\s*分数[:：]\s*([0-9]+)"


def get_chatgpt_reply(query, context=None):
    """Answer *query* from the indexed transcript and extend the chat history.

    Args:
        query: The user's question.
        context: Flat chat history ``[q1, a1, q2, a2, ...]`` (the Gradio
            ``State`` value). It is not modified; ``None`` means no history.

    Returns:
        A ``(responses, history)`` tuple. ``responses`` is the list of
        ``(question, answer)`` pairs shown by the Chatbot, and ``history`` is
        the new flat history including this turn.
    """
    # Build a new list instead of mutating the caller's: a shared ``[]``
    # default would leak history between calls, and an in-place append would
    # leave an unanswered question in the session state if the chain raised.
    history = list(context) if context else []
    history.append(query)

    llm_chat = ChatOpenAI(model_name="gpt-3.5-turbo", max_tokens=2000, temperature=0.3)
    embeddings = OpenAIEmbeddings()
    # Re-open the persisted store on each call so that documents indexed by
    # upload_file are visible to the retriever.
    docsearch = Chroma(persist_directory="./VectorDB", embedding_function=embeddings)
    output_parser = RegexParser(
        regex=QA_OUTPUT_REGEX,
        output_keys=["answer", "score"],
    )
    prompt = PromptTemplate(
        template=qa_prompt,
        input_variables=["context", "question"],
        output_parser=output_parser,
    )
    # map_rerank runs the prompt per retrieved chunk and, per LangChain's
    # docs, keeps the answer whose parsed "score" is highest.
    qa = RetrievalQA.from_chain_type(
        llm_chat,
        chain_type="map_rerank",
        retriever=docsearch.as_retriever(),
        chain_type_kwargs={"prompt": prompt},
    )
    result = qa.run(query)
    history.append(result)

    responses = list(zip(history[::2], history[1::2]))
    return responses, history
def get_chatgpt_summary(content):
    """Summarise the main topics of a meeting transcript, chunk by chunk.

    The transcript is split with the module-level ``text_splitter``. Each
    chunk is summarised independently using ``summary_prompt``, and the
    per-chunk summaries are joined with a ``*******`` separator line.

    Args:
        content: Transcript text (the "文稿内容" textbox value). ``None`` or
            text that yields no chunks returns "" without calling the API.

    Returns:
        The combined summary string.
    """
    if not content:
        return ""
    texts = text_splitter.split_text(content)
    if not texts:
        return ""
    docs = [LangDoc(page_content=t) for t in texts]
    # gpt-3.5-turbo is a chat model, so use LangChain's chat wrapper (as
    # get_chatgpt_reply does) rather than the completion-style OpenAI class.
    llm_summary = ChatOpenAI(model_name="gpt-3.5-turbo", max_tokens=300, temperature=0.2)
    each_round_template = PromptTemplate(input_variables=["text"], template=summary_prompt)
    chain_summary = load_summarize_chain(llm_summary, chain_type="stuff", prompt=each_round_template)
    # One chain call per chunk; each call receives a single document.
    return "\n*******\n".join(chain_summary.run([doc]) for doc in docs)
import gradio as gr | |
from docx import Document | |
import os | |
def upload_file(file):
    """Read an uploaded .docx, index its text in the Chroma store, return the text.

    Args:
        file: The Gradio upload object, whose ``.name`` is passed to
            python-docx as the file path, or ``None`` when the upload button
            is clicked without a file.

    Returns:
        The document's paragraph texts, each followed by a newline, or ""
        when no file was given.
    """
    if file is None:
        return ""
    doc = Document(file.name)
    content = "".join(f"{para.text}\n" for para in doc.paragraphs)

    docs = [LangDoc(page_content=t) for t in text_splitter.split_text(content)]
    # An empty document yields no chunks; skip indexing instead of asking
    # Chroma/OpenAI to embed an empty batch.
    if docs:
        # NOTE(review): from_documents appears to add to whatever is already
        # persisted in ./VectorDB, so chunks from earlier uploads stay
        # retrievable in get_chatgpt_reply — confirm whether the store should
        # be reset for each new upload.
        docsearch = Chroma.from_documents(docs, OpenAIEmbeddings(), persist_directory="./VectorDB")
        docsearch.persist()
    return content
def set_api_key(api_key):
    """Make *api_key* the OpenAI credential for the rest of this process.

    The key is written both to the ``openai`` module and to the
    ``OPENAI_API_KEY`` environment variable. The LangChain objects used by the
    other handlers are constructed inside those handlers, i.e. after this has
    run (presumably reading the key from that variable). Wired to the "设置"
    button, which has no outputs, so nothing is returned.
    """
    openai.api_key = api_key
    os.environ["OPENAI_API_KEY"] = api_key
# Gradio UI: API-key entry, .docx upload with original/summary tabs, and a
# retrieval-QA chat panel.
with gr.Blocks(theme=gr.themes.Default(text_size='lg', radius_size='sm')) as demo:
    with gr.Column():
        # Product introduction
        title = gr.Markdown("# <center>ChatMeeting</center>")
        desc = gr.Markdown("<center>让AI帮你整理会议纪要\n\n支持.docx文件</center>")
    with gr.Column():
        # API key entry
        api_input = gr.Textbox(label="API Key", placeholder="请输入API Key", type="password")
        api_btn = gr.Button(value="设置")
        api_btn.click(fn=set_api_key, inputs=api_input, outputs=None)
    with gr.Row():
        with gr.Column():
            # File upload
            file_input = gr.File(file_types=[".docx"], label="原始文稿", interactive=True)
            upload_btn = gr.Button(value="上传")
            # Text display
            with gr.Tab("原文"):
                # Original text
                content_box = gr.Textbox(label="文稿内容")
            with gr.Tab("总结"):
                # Summary
                summary_box = gr.Textbox(label="总结内容")
        with gr.Column():
            # Chat interaction
            # NOTE(review): Component.style() was removed in Gradio 4; if the
            # installed Gradio rejects it, pass height=500 to gr.Chatbot instead.
            chatbot = gr.Chatbot(label="对话内容").style(height=500)
            # Flat [question, answer, ...] history consumed by get_chatgpt_reply.
            state = gr.State([])
            txt = gr.Textbox(label="用户", placeholder="请输入内容")
            with gr.Row():
                summary = gr.Button(value="一键总结")
                clear = gr.Button(value="清空")
    summary.click(fn=get_chatgpt_summary, inputs=content_box, outputs=summary_box)
    txt.submit(get_chatgpt_reply, [txt, state], [chatbot, state])
    # Clear the history state together with the display; otherwise the old
    # conversation reappears on the next message.
    clear.click(lambda: (None, []), None, [chatbot, state], queue=False)
    # Upload the file and index it with LangChain
    upload_btn.click(fn=upload_file, inputs=file_input, outputs=content_box)
demo.launch()