Spaces:
Paused
Paused
File size: 2,916 Bytes
5795fcf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import os
from typing import List
import chainlit as cl
from llama_index.callbacks.base import CallbackManager
from llama_index import (
ServiceContext,
StorageContext,
load_index_from_storage,
)
from llama_index.llms import OpenAI
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.query_engine import SubQuestionQueryEngine
from llama_index.embeddings import HuggingFaceEmbedding
from chainlit.types import AskFileResponse
from llama_index import download_loader
# Load the persisted index once at import time; it is shared by every chat session.
print("Loading Storage Context...")
storage_context = StorageContext.from_defaults(persist_dir="index/")
print("Loading Index...")
# The retriever built by `index.as_query_engine(...)` embeds queries with the
# index's OWN service context, and `index.insert(...)` embeds new documents with
# it too. Load with the same fine-tuned embedding model the query engine is
# configured with so stored vectors, inserted vectors and query vectors all
# come from one model.
# NOTE(review): assumes index/ was built with this embedding model — confirm.
index = load_index_from_storage(
    storage_context,
    service_context=ServiceContext.from_defaults(
        embed_model=HuggingFaceEmbedding(
            model_name="ai-maker-space/chatlgo-finetuned"
        ),
    ),
)
def process_file(file: AskFileResponse):
    """Parse an uploaded PDF into LlamaIndex documents.

    The upload's bytes are written to a temporary ``.pdf`` file, which is read
    with the ``PDFReader`` loader obtained through ``download_loader``. The
    temporary file is always removed afterwards.

    Args:
        file: The Chainlit upload; its ``content`` attribute holds the raw bytes.

    Returns:
        Whatever ``PDFReader.load_data`` returns for the file (a list of
        documents).
    """
    import tempfile
    from pathlib import Path

    # delete=False lets the reader reopen the file by name after this handle is
    # closed (reopening a still-open NamedTemporaryFile fails on Windows);
    # cleanup is therefore done explicitly in `finally`.
    tmp = tempfile.NamedTemporaryFile(mode="wb", suffix=".pdf", delete=False)
    try:
        with tmp:
            tmp.write(file.content)
        pdf_reader_cls = download_loader("PDFReader")
        # The reader's `file` parameter is annotated as a Path; pass one rather
        # than relying on str being accepted.
        return pdf_reader_cls().load_data(Path(tmp.name))
    finally:
        os.remove(tmp.name)
async def _ask_for_pdf():
    """Prompt until the user uploads a PDF and return the first uploaded file."""
    files = None
    # send() can come back with no files (e.g. the prompt timed out), so re-ask.
    while not files:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()
    return files[0]


def _insert_documents(target_index, documents):
    """Insert each document into *target_index* one by one (blocking)."""
    for document in documents:
        target_index.insert(document)


def _build_query_engine(base_index):
    """Wrap *base_index* in a Cohere-reranked, sub-question query engine."""
    service_context = ServiceContext.from_defaults(
        embed_model=HuggingFaceEmbedding(
            model_name="ai-maker-space/chatlgo-finetuned"
        ),
        llm=OpenAI(model="gpt-4-1106-preview", temperature=0),
    )
    # Retrieve 10 candidates, then let Cohere rerank down to the best 5.
    base_engine = base_index.as_query_engine(
        similarity_top_k=10,
        node_postprocessors=[CohereRerank(top_n=5)],
        service_context=service_context,
    )
    tools = [
        QueryEngineTool(
            query_engine=base_engine,
            metadata=ToolMetadata(
                name="mit_theses",
                description="A collection of MIT theses.",
            ),
        ),
    ]
    return SubQuestionQueryEngine.from_defaults(
        query_engine_tools=tools,
        service_context=service_context,
    )


@cl.on_chat_start
async def on_chat_start():
    """Collect a PDF from the user, add it to the index, and set up querying.

    The resulting query engine is stored in the Chainlit user session under
    ``"query_engine"`` for the message handler to use.
    """
    file = await _ask_for_pdf()

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # Parsing, embedding and model loading are blocking; run them in worker
    # threads so the event loop keeps serving other sessions.
    documents = await cl.make_async(process_file)(file)
    # `index` is the module-level index. It is only read here, never rebound, so
    # no `global` declaration is needed.
    # NOTE(review): this index is shared, so an upload becomes retrievable in
    # every session, and concurrent inserts from several sessions are not
    # synchronized — confirm that is intended.
    await cl.make_async(_insert_documents)(index, documents)
    query_engine = await cl.make_async(_build_query_engine)(index)

    cl.user_session.set("query_engine", query_engine)

    msg.content = f"`{file.name}` processed. You can now ask questions!"
    await msg.update()
@cl.on_message
async def main(message: cl.Message):
    """Answer a chat message using this session's query engine.

    If no query engine has been stored for the session (setup has not finished
    or failed), reply with an upload prompt instead of raising.
    """
    query_engine = cl.user_session.get("query_engine")
    if query_engine is None:
        await cl.Message(
            content="Please upload a PDF file before asking questions."
        ).send()
        return

    # query() is synchronous; run it in a worker thread to keep the loop free.
    response = await cl.make_async(query_engine.query)(message.content)
    await cl.Message(content=str(response)).send()
|