Spaces:
Paused
Paused
import os | |
from typing import List | |
import chainlit as cl | |
from llama_index.callbacks.base import CallbackManager | |
from llama_index import ( | |
ServiceContext, | |
StorageContext, | |
load_index_from_storage, | |
) | |
from llama_index.llms import OpenAI | |
from llama_index.postprocessor.cohere_rerank import CohereRerank | |
from llama_index.tools import QueryEngineTool, ToolMetadata | |
from llama_index.query_engine import SubQuestionQueryEngine | |
from llama_index.embeddings import HuggingFaceEmbedding | |
from chainlit.types import AskFileResponse | |
from llama_index import download_loader | |
from llama_index import VectorStoreIndex | |
def process_file(file: AskFileResponse):
    """Parse an uploaded PDF into LlamaIndex documents.

    The upload's bytes are written to a temporary ``.pdf`` file, which is handed
    to the ``PDFReader`` loader and deleted once parsing finishes.

    Args:
        file: The Chainlit upload; its ``content`` bytes are the PDF data.

    Returns:
        The list of documents produced by ``PDFReader.load_data``.
    """
    import tempfile
    from pathlib import Path

    # Write through a single binary handle (the upload is raw bytes). Close it
    # before parsing so the reader can reopen the file on every platform.
    with tempfile.NamedTemporaryFile(mode="wb", suffix=".pdf", delete=False) as tmp:
        tmp.write(file.content)
        tmp_path = tmp.name

    try:
        # Resolve the PDFReader loader class by name via llama_index's loader hub.
        PDFReader = download_loader("PDFReader")
        loader = PDFReader()
        # NOTE(review): load_data is believed to be typed as taking a Path;
        # passing one works whether or not the installed reader also accepts str.
        documents = loader.load_data(Path(tmp_path))
    finally:
        # delete=False above means nothing else removes the file; don't leak it.
        os.remove(tmp_path)
    return documents
@cl.on_chat_start
async def on_chat_start():
    """Ask for a PDF, index it, and store a sub-question query engine in the session.

    Pipeline:
    1. The uploaded PDF is embedded with the fine-tuned HuggingFace model.
    2. Retrieval fetches the top 10 nodes.
    3. Cohere reranks them down to 5.
    4. The vector engine is wrapped as the single tool of a
       ``SubQuestionQueryEngine`` driven by GPT-4.

    The final engine is saved under the ``"query_engine"`` key of
    ``cl.user_session`` for the message handler.
    """
    files = None
    # Keep prompting until a file arrives. send() presumably yields None when the
    # 180 s prompt times out — TODO confirm against the Chainlit version in use.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()
    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # Parsing is blocking work; run it off the event loop.
    documents = await cl.make_async(process_file)(file)

    # One embedding model, loaded once, shared by indexing and querying, so
    # documents and queries are embedded by the same (fine-tuned) model.
    llm = OpenAI(model="gpt-4-1106-preview", temperature=0)
    embed_model = HuggingFaceEmbedding(model_name="ai-maker-space/chatlgo-finetuned")
    service_context = ServiceContext.from_defaults(
        embed_model=embed_model,
        llm=llm,
    )

    # from_documents takes the context as ``service_context=``. Embedding every
    # chunk is blocking, so this also runs off the event loop.
    index = await cl.make_async(VectorStoreIndex.from_documents)(
        documents=documents,
        service_context=service_context,
        show_progress=True,
    )

    # Over-fetch 10 candidates, then let Cohere keep the best 5.
    cohere_rerank = CohereRerank(top_n=5)
    vector_query_engine = index.as_query_engine(
        similarity_top_k=10,
        node_postprocessors=[cohere_rerank],
        service_context=service_context,
    )
    # NOTE(review): the tool name/description assume the upload is an MIT thesis;
    # the LLM reads them when generating sub-questions, so confirm this still fits
    # arbitrary PDFs.
    query_engine_tools = [
        QueryEngineTool(
            query_engine=vector_query_engine,
            metadata=ToolMetadata(
                name="mit_theses",
                description="A collection of MIT theses.",
            ),
        ),
    ]
    query_engine = SubQuestionQueryEngine.from_defaults(
        query_engine_tools=query_engine_tools,
        service_context=service_context,
    )
    cl.user_session.set("query_engine", query_engine)

    # Replace the "Processing..." notice so the user knows they can ask questions.
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()
@cl.on_message
async def main(message: cl.Message):
    """Answer a chat message using the query engine stored in the user session.

    Args:
        message: The incoming chat message; its ``content`` text is the question.
    """
    query_engine = cl.user_session.get("query_engine")
    if query_engine is None:
        # on_chat_start has not finished (or failed), so there is nothing to query.
        await cl.Message(content="Please upload a PDF file first.").send()
        return

    # query() is a blocking call; wrap it with make_async instead of running it
    # directly on the event loop.
    response = await cl.make_async(query_engine.query)(message.content)
    await cl.Message(content=str(response)).send()