acpotts committed
Commit d891fa6 · verified · 1 Parent(s): 2823e81

Upload app.py

Files changed (1)
  1. app.py +54 -133
app.py CHANGED
@@ -1,134 +1,55 @@
-import os
-from typing import List
-from chainlit.types import AskFileResponse
-from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader
-from aimakerspace.openai_utils.prompts import (
-    UserRolePrompt,
-    SystemRolePrompt,
-    AssistantRolePrompt,
-)
-from aimakerspace.openai_utils.embedding import EmbeddingModel
-from aimakerspace.vectordatabase import VectorDatabase
-from aimakerspace.openai_utils.chatmodel import ChatOpenAI
-import chainlit as cl
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-# from langchain_experimental.text_splitter import SemanticChunker
-# from langchain_openai.embeddings import OpenAIEmbeddings
-
-system_template = """\
-Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
-system_role_prompt = SystemRolePrompt(system_template)
-
-user_prompt_template = """\
-Context:
-{context}
-Question:
-{question}
-"""
-user_role_prompt = UserRolePrompt(user_prompt_template)
-
-class RetrievalAugmentedQAPipeline:
-    def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
-        self.llm = llm
-        self.vector_db_retriever = vector_db_retriever
-
-    async def arun_pipeline(self, user_query: str):
-        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
-
-        context_prompt = ""
-        for context in context_list:
-            context_prompt += context[0] + "\n"
-
-        formatted_system_prompt = system_role_prompt.create_message()
-
-        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
-
-        async def generate_response():
-            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
-                yield chunk
-
-        return {"response": generate_response(), "context": context_list}
-
-text_splitter = RecursiveCharacterTextSplitter()
-
-
-def process_text_file(file: AskFileResponse):
-    import tempfile
-    from langchain_community.document_loaders.pdf import PyPDFLoader
-
-    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=file.name) as temp_file:
-        temp_file_path = temp_file.name
-
-    with open(temp_file_path, "wb") as f:
-        f.write(file.content)
-
-    if file.type == 'text/plain':
-        text_loader = TextFileLoader(temp_file_path)
-        documents = text_loader.load_documents()
-    elif file.type == 'application/pdf':
-        pdf_loader = PyPDFLoader(temp_file_path)
-        documents = pdf_loader.load()
-    else:
-        raise ValueError("Provide a .txt or .pdf file")
-    texts = [x.page_content for x in text_splitter.transform_documents(documents)]
-    # texts = [x.page_content for x in text_splitter.split_documents(documents)]
-    return texts
-
-
-
-@cl.on_chat_start
-async def on_chat_start():
-    files = None
-
-    # Wait for the user to upload a file
-    while files == None:
-        files = await cl.AskFileMessage(
-            content="Please upload a Text file or a PDF to begin!",
-            accept=["text/plain", "application/pdf"],
-            max_size_mb=12,
-            timeout=180,
-            max_files=10
-        ).send()
-    vector_db = VectorDatabase()
-    for file in files:
-
-        msg = cl.Message(
-            content=f"Processing `{file.name}`...", disable_human_feedback=True
-        )
-        await msg.send()
-
-        # load the file
-        texts = process_text_file(file)
-
-        print(f"Processing {len(texts)} text chunks")
-
-        # Create a dict vector store
-
-        vector_db = await vector_db.abuild_from_list(texts)
-
-        chat_openai = ChatOpenAI()
-
-        # Create a chain
-        retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
-            vector_db_retriever=vector_db,
-            llm=chat_openai
-        )
-
-        # Let the user know that the system is ready
-        msg.content = f"Processing `{file.name}` done. You can now ask questions!"
-        await msg.update()
-
-        cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
-
-
-@cl.on_message
-async def main(message):
-    chain = cl.user_session.get("chain")
-
-    msg = cl.Message(content="")
-    result = await chain.arun_pipeline(message.content)
-
-    async for stream_resp in result["response"]:
-        await msg.stream_token(stream_resp)
-
+### Import Section ###
+"""
+IMPORTS HERE
+"""
+# Example Imports (adjust based on actual needs)
+import chainlit as cl
+from langchain.chat_models import ChatOpenAI
+from langchain.chains import ConversationChain
+from langchain.prompts import ChatPromptTemplate
+from langchain.schema import StrOutputParser
+from langchain.schema.runnable import Runnable
+from langchain.schema.runnable.config import RunnableConfig
+from typing import cast
+
+### Global Section ###
+"""
+GLOBAL CODE HERE
+"""
+# Initialize a language model or chain globally
+llm = ChatOpenAI(temperature=0.9)
+conversation_chain = ConversationChain(llm=llm)
+
+# Any global variables like API keys, configurations, etc.
+# API_KEY = "your_api_key_here"
+
+
+### On Chat Start (Session Start) Section ###
+@cl.on_chat_start
+async def on_chat_start():
+    """ SESSION SPECIFIC CODE HERE """
+    await cl.Message(content="Welcome! How can I assist you today?").send()
+
+### Rename Chains ###
+@cl.author_rename
+def rename(orig_author: str):
+    if orig_author == "user":
+        return "You"
+    elif orig_author == "system":
+        return "Assistant"
+    return orig_author
+
+### On Message Section ###
+@cl.on_message
+async def on_message(message: cl.Message):
+    runnable = cast(Runnable, cl.user_session.get("runnable"))
+
+    msg = cl.Message(content="")
+
+    async for chunk in runnable.astream(
+        {"question": message.content},
+        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
+    ):
+        await msg.stream_token(chunk)
+
     await msg.send()
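Note on the new skeleton: `on_message` casts whatever sits under the session key "runnable", but nothing in this commit ever sets that key, so `runnable.astream(...)` would fail on a `None` value at runtime. Below is a minimal sketch of the missing wiring, assuming a plain LCEL chain built from the imports the skeleton already declares; the prompt wording, model settings, and the chain composition are illustrative assumptions, not part of the commit.

### On Chat Start (Session Start) Section ###
@cl.on_chat_start
async def on_chat_start():
    # Hypothetical wiring (not in this commit): build a simple
    # prompt -> model -> parser chain and store it under the
    # "runnable" key that on_message expects to find.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful assistant."),  # illustrative prompt text
            ("human", "{question}"),
        ]
    )
    runnable = prompt | ChatOpenAI(temperature=0.9, streaming=True) | StrOutputParser()
    cl.user_session.set("runnable", runnable)
    await cl.Message(content="Welcome! How can I assist you today?").send()

Storing the chain per session in `on_chat_start` keeps each user's conversation state isolated, which is the usual Chainlit pattern the `cl.user_session.get("runnable")` call in `on_message` implies.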