Some refactoring of the memory interface
- megabots/__init__.py +6 -253
- megabots/bot.py +192 -0
- megabots/memory.py +86 -0
- megabots/utils.py +39 -0
- megabots/{vectorstores.py → vectorstore.py} +6 -4
- tests/test_memory.py +42 -0
megabots/__init__.py
CHANGED
@@ -1,256 +1,9 @@
-from typing import Any
-from langchain.llms import OpenAI
-from langchain.chat_models import ChatOpenAI
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.chains.qa_with_sources import load_qa_with_sources_chain
-from langchain.vectorstores.faiss import FAISS
-import gradio as gr
-from fastapi import FastAPI
-import pickle
-import os
-from dotenv import load_dotenv
-from langchain.prompts import PromptTemplate
-from langchain.chains.question_answering import load_qa_chain
-from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
-from langchain.document_loaders import DirectoryLoader
-from megabots.vectorstores import VectorStore, vectorstore
-from langchain.memory import (
-    ConversationBufferMemory,
-    ConversationBufferWindowMemory,
-    ConversationSummaryMemory,
-    ConversationSummaryBufferMemory,
-)
-
-load_dotenv()
-
-
-class Bot:
-    def __init__(
-        self,
-        model: str | None = None,
-        prompt_template: str | None = None,
-        prompt_variables: list[str] | None = None,
-        index: str | None = None,
-        sources: bool | None = False,
-        vectorstore: VectorStore | None = None,
-        memory: str | None = None,
-        memory_window: int = 3,
-        verbose: bool = False,
-        temperature: int = 0,
-    ):
-        self.select_model(model, temperature)
-        self.create_loader(index)
-        self.load_or_create_index(index, vectorstore)
-
-        # Load the question-answering chain for the selected model
-        self.chain = self.create_chain(
-            prompt_template, prompt_variables, sources=sources, verbose=verbose
-        )
-
-
-    def create_chain(
-        self,
-        prompt_template: str | None = None,
-        prompt_variables: list[str] | None = None,
-        sources: bool | None = False,
-        verbose: bool = False,
-    ):
-        prompt = (
-            PromptTemplate(template=prompt_template, input_variables=prompt_variables)
-            if prompt_template is not None and prompt_variables is not None
-            else QA_PROMPT
-        )
-        # TODO: Changing the prompt here is not working. Leave it as is for now.
-        # Reference: https://github.com/hwchase17/langchain/issues/2858
-        if sources:
-            return load_qa_with_sources_chain(
-                self.llm, chain_type="stuff", verbose=verbose
-            )
-        return load_qa_chain(
-            self.llm, chain_type="stuff", verbose=verbose, prompt=prompt
-        )
-
-    def select_model(self, model: str | None, temperature: float):
-        # Select and set the appropriate model based on the provided input
-        if model is None or model == "gpt-3.5-turbo":
-            print("Using model: gpt-3.5-turbo")
-            self.llm = ChatOpenAI(temperature=temperature)
-
-        if model == "text-davinci-003":
-            print("Using model: text-davinci-003")
-            self.llm = OpenAI(temperature=temperature)
-
-    def create_loader(self, index: str | None):
-        # Create a loader based on the provided directory (either local or S3)
-        if index is None:
-            raise RuntimeError(
-                """
-                Impossible to find a valid index.
-                Either provide a valid path to a pickle file or a directory.
-                """
-            )
-        self.loader = DirectoryLoader(index, recursive=True)
-
-    def load_or_create_index(self, index: str, vectorstore: VectorStore | None = None):
-        # Load an existing index from disk or create a new one if not available
-        if vectorstore is not None:
-            self.search_index = vectorstore.client.from_documents(
-                self.loader.load_and_split(),
-                OpenAIEmbeddings(),
-                connection_args={"host": vectorstore.host, "port": vectorstore.port},
-            )
-            return
-
-        # Is pickle
-        if index is not None and "pkl" in index or "pickle" in index:
-            print("Loading path from pickle file: ", index, "...")
-            with open(index, "rb") as f:
-                self.search_index = pickle.load(f)
-            return
-
-        # Is directory
-        if index is not None and os.path.isdir(index):
-            print("Creating index...")
-            self.search_index = FAISS.from_documents(
-                self.loader.load_and_split(), OpenAIEmbeddings()
-            )
-            return
-
-        raise RuntimeError(
-            """
-            Impossible to find a valid index.
-            Either provide a valid path to a pickle file or a directory.
-            """
-        )
-
-    def save_index(self, index_path: str):
-        # Save the index to the specified path
-        with open(index_path, "wb") as f:
-            pickle.dump(self.search_index, f)
-
-    def ask(self, question: str, k=1) -> str:
-        # Retrieve the answer to the given question and return it
-        input_documents = self.search_index.similarity_search(question, k=k)
-        answer = self.chain.run(input_documents=input_documents, question=question)
-        return answer
-
-
-SUPPORTED_TASKS = {
-    "qna-over-docs": {
-        "impl": Bot,
-        "default": {
-            "model": "gpt-3.5-turbo",
-            "temperature": 0,
-            "index": "./index",
-        },
-    }
-}
 
-SUPPORTED_MODELS = {}
 
-
-    "conversation-buffer-window": {
-        "impl": ConversationBufferWindowMemory,
-        "default": {"memory_window": 3},
-    },
-    "conversation-buffer": {
-        "impl": ConversationBufferMemory,
-        "default": {},
-    },
-    "conversation-summary": {
-        "impl": ConversationSummaryMemory,
-        "default": {},
-    "conversation-summary-buffer": {
-        "impl": ConversationSummaryBufferMemory,
-        "default": {
-            "max_token_limit":40
-        }
-    },
-}
-
-
-def bot(
-    task: str | None = None,
-    model: str | None = None,
-    index: str | None = None,
-    prompt_template: str | None = None,
-    prompt_variables: list[str] | None = None,
-    memory: str | None = None,
-    memory_window: int = 3,
-    verbose: bool = False,
-    temperature: int = 0,
-    **kwargs,
-) -> Bot:
-    """Instanciate a bot based on the provided task. Each supported tasks has it's own default sane defaults.
-
-    Args:
-        task (str | None, optional): The given task. Can be one of the SUPPORTED_TASKS.
-        model (str | None, optional): Model to be used. Can be one of the SUPPORTED_MODELS.
-        index (str | None, optional): Data that the model will load and store index info.
-        Can be either a local file path, a pickle file, or a url of a vector database.
-        By default it will look for a local directory called "files" in the current working directory.
-        prompt_template (str | None, optional): Prompt template to be used. Specify variables with {var_name}.
-        prompt_variables (list[str] | None, optional): Prompt variables to be used in the prompt template.
-        verbose (bool, optional): Verbocity. Defaults to False.
-        temperature (int, optional): Temperature. Defaults to 0.
-
-    Raises:
-        RuntimeError: _description_
-        ValueError: _description_
-
-    Returns:
-        Bot: Bot instance
-    """
-
-    if task is None:
-        raise RuntimeError("Impossible to instantiate a bot without a task.")
-    if task not in SUPPORTED_TASKS:
-        raise ValueError(f"Task {task} is not supported.")
-
-    task_defaults = SUPPORTED_TASKS[task]["default"]
-    return SUPPORTED_TASKS[task]["impl"](
-        model=model or task_defaults["model"],
-        index=index or task_defaults["index"],
-        prompt_template=prompt_template,
-        prompt_variables=prompt_variables,
-        temperature=temperature,
-        verbose=verbose,
-        **kwargs,
-    )
-
-
-def create_api(bot: Bot):
-    app = FastAPI()
-
-    @app.get("/v1/ask/{question}")
-    async def ask(question: str):
-        answer = bot.ask(question)
-        return {"answer": answer}
-
-    return app
-
-
-def create_interface(bot_instance: Bot, examples: list[list[str]] = []):
-    with gr.Blocks() as interface:
-        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
-        msg = gr.Textbox(
-            show_label=False,
-            placeholder="Enter text and press enter, or upload an image",
-        ).style(container=False)
-        clear = gr.Button("Clear")
-
-        def user(user_message, history):
-            return "", history + [[user_message, None]]
-
-        def bot(history):
-            response = bot_instance.ask(history[-1][0])
-            history[-1][1] = response
-            return history
-
-        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-            bot, chatbot, chatbot
-        )
-        clear.click(lambda: None, None, chatbot, queue=False)
-
+from megabots.vectorstore import VectorStore, vectorstore
+from megabots.memory import Memory, memory
+from megabots.bot import Bot, bot
+from megabots.utils import create_api, create_interface
 
 
+from dotenv import load_dotenv
 
+load_dotenv()
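
After this refactor, megabots/__init__.py is reduced to re-exports, so the public API stays the same for consumers. A minimal usage sketch, assuming an OPENAI_API_KEY in the environment and a local ./index directory of documents (both assumptions, not shown in this commit):

from megabots import bot

# Build a Q&A bot over the documents in ./index and ask it a question.
qnabot = bot("qna-over-docs", index="./index")
print(qnabot.ask("What is this codebase about?"))
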
megabots/bot.py
ADDED
@@ -0,0 +1,192 @@
+from typing import Any
+from langchain.llms import OpenAI
+from langchain.chat_models import ChatOpenAI
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain.vectorstores.faiss import FAISS
+import pickle
+import os
+from langchain.prompts import PromptTemplate
+from langchain.chains.question_answering import load_qa_chain
+from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
+from langchain.document_loaders import DirectoryLoader
+from megabots.vectorstore import VectorStore
+from megabots.memory import Memory
+import megabots
+
+
+class Bot:
+    def __init__(
+        self,
+        model: str | None = None,
+        prompt_template: str | None = None,
+        prompt_variables: list[str] | None = None,
+        index: str | None = None,
+        sources: bool | None = False,
+        vectorstore: VectorStore | None = None,
+        memory: Memory | None = None,
+        verbose: bool = False,
+        temperature: int = 0,
+    ):
+        self.select_model(model, temperature)
+        self.create_loader(index)
+        self.load_or_create_index(index, vectorstore)
+        self.vectorstore = vectorstore
+        self.memory = memory
+        # Load the question-answering chain for the selected model
+        self.chain = self.create_chain(
+            prompt_template, prompt_variables, sources=sources, verbose=verbose
+        )
+
+    def create_chain(
+        self,
+        prompt_template: str | None = None,
+        prompt_variables: list[str] | None = None,
+        sources: bool | None = False,
+        verbose: bool = False,
+    ):
+        prompt = (
+            PromptTemplate(template=prompt_template, input_variables=prompt_variables)
+            if prompt_template is not None and prompt_variables is not None
+            else QA_PROMPT
+        )
+        # TODO: Changing the prompt here is not working. Leave it as is for now.
+        # Reference: https://github.com/hwchase17/langchain/issues/2858
+        if sources:
+            return load_qa_with_sources_chain(
+                self.llm, chain_type="stuff", verbose=verbose
+            )
+        return load_qa_chain(
+            self.llm, chain_type="stuff", verbose=verbose, prompt=prompt
+        )
+
+    def select_model(self, model: str | None, temperature: float):
+        # Select and set the appropriate model based on the provided input
+        if model is None or model == "gpt-3.5-turbo":
+            print("Using model: gpt-3.5-turbo")
+            self.llm = ChatOpenAI(temperature=temperature)
+
+        if model == "text-davinci-003":
+            print("Using model: text-davinci-003")
+            self.llm = OpenAI(temperature=temperature)
+
+    def create_loader(self, index: str | None):
+        # Create a loader based on the provided directory (either local or S3)
+        if index is None:
+            raise RuntimeError(
+                """
+                Impossible to find a valid index.
+                Either provide a valid path to a pickle file or a directory.
+                """
+            )
+        self.loader = DirectoryLoader(index, recursive=True)
+
+    def load_or_create_index(self, index: str, vectorstore: VectorStore | None = None):
+        # Load an existing index from disk or create a new one if not available
+        if vectorstore is not None:
+            self.search_index = vectorstore.client.from_documents(
+                self.loader.load_and_split(),
+                OpenAIEmbeddings(),
+                connection_args={"host": vectorstore.host, "port": vectorstore.port},
+            )
+            return
+
+        # Is pickle
+        if index is not None and ("pkl" in index or "pickle" in index):
+            print("Loading path from pickle file: ", index, "...")
+            with open(index, "rb") as f:
+                self.search_index = pickle.load(f)
+            return
+
+        # Is directory
+        if index is not None and os.path.isdir(index):
+            print("Creating index...")
+            self.search_index = FAISS.from_documents(
+                self.loader.load_and_split(), OpenAIEmbeddings()
+            )
+            return
+
+        raise RuntimeError(
+            """
+            Impossible to find a valid index.
+            Either provide a valid path to a pickle file or a directory.
+            """
+        )
+
+    def save_index(self, index_path: str):
+        # Save the index to the specified path
+        with open(index_path, "wb") as f:
+            pickle.dump(self.search_index, f)
+
+    def ask(self, question: str, k=1) -> str:
+        # Retrieve the answer to the given question and return it
+        input_documents = self.search_index.similarity_search(question, k=k)
+        answer = self.chain.run(input_documents=input_documents, question=question)
+        return answer
+
+
+SUPPORTED_TASKS = {
+    "qna-over-docs": {
+        "impl": Bot,
+        "default": {
+            "model": "gpt-3.5-turbo",
+            "temperature": 0,
+            "index": "./index",
+        },
+    }
+}
+
+SUPPORTED_MODELS = {}
+
+
+def bot(
+    task: str | None = None,
+    model: str | None = None,
+    index: str | None = None,
+    prompt_template: str | None = None,
+    prompt_variables: list[str] | None = None,
+    memory: str | Memory | None = None,
+    vectorstore: str | VectorStore | None = None,
+    verbose: bool = False,
+    temperature: int = 0,
+) -> Bot:
+    """Instantiate a bot based on the provided task. Each supported task has its own sane defaults.
+
+    Args:
+        task (str | None, optional): The given task. Can be one of the SUPPORTED_TASKS.
+        model (str | None, optional): Model to be used. Can be one of the SUPPORTED_MODELS.
+        index (str | None, optional): Data that the model will load and store index info.
+        Can be either a local file path, a pickle file, or a url of a vector database.
+        By default it will look for a local directory called "files" in the current working directory.
+        prompt_template (str | None, optional): Prompt template to be used. Specify variables with {var_name}.
+        prompt_variables (list[str] | None, optional): Prompt variables to be used in the prompt template.
+        memory (str | Memory | None, optional): Memory to be used. Either a name from SUPPORTED_MEMORY or a Memory object.
+        vectorstore (str | VectorStore | None, optional): Vector store to be used. Either a name from SUPPORTED_VECTORSTORES or a VectorStore object.
+        verbose (bool, optional): Verbosity. Defaults to False.
+        temperature (int, optional): Temperature. Defaults to 0.
+
+    Raises:
+        RuntimeError: If no task is provided.
+        ValueError: If the given task is not supported.
+
+    Returns:
+        Bot: Bot instance
+    """
+
+    if task is None:
+        raise RuntimeError("Impossible to instantiate a bot without a task.")
+    if task not in SUPPORTED_TASKS:
+        raise ValueError(f"Task {task} is not supported.")
+
+    task_defaults = SUPPORTED_TASKS[task]["default"]
+
+    return SUPPORTED_TASKS[task]["impl"](
+        model=model or task_defaults["model"],
+        index=index or task_defaults["index"],
+        prompt_template=prompt_template,
+        prompt_variables=prompt_variables,
+        temperature=temperature,
+        verbose=verbose,
+        vectorstore=megabots.vectorstore(vectorstore)
+        if isinstance(vectorstore, str)
+        else vectorstore,
+        memory=megabots.memory(memory) if isinstance(memory, str) else memory,
+    )
megabots/memory.py
ADDED
@@ -0,0 +1,86 @@
+from langchain.memory import (
+    ConversationBufferMemory,
+    ConversationBufferWindowMemory,
+    ConversationSummaryMemory,
+    ConversationSummaryBufferMemory,
+)
+
+
+class ConversationBuffer:
+    def __init__(self):
+        self.memory = ConversationBufferMemory
+
+
+class ConversationBufferWindow:
+    def __init__(self, memory_window: int):
+        self.memory_window: int = memory_window
+        self.memory = ConversationBufferWindowMemory
+
+
+class ConversationSummary:
+    def __init__(self):
+        self.memory = ConversationSummaryMemory
+
+
+class ConversationSummaryBuffer:
+    def __init__(self, max_token_limit: int):
+        self.max_token_limit: int = max_token_limit
+        self.memory = ConversationSummaryBufferMemory
+
+
+SUPPORTED_MEMORY = {
+    "conversation-buffer": {
+        "impl": ConversationBuffer,
+        "default": {},
+    },
+    "conversation-buffer-window": {
+        "impl": ConversationBufferWindow,
+        "default": {"memory_window": 3},
+    },
+    "conversation-summary": {
+        "impl": ConversationSummary,
+        "default": {},
+    },
+    "conversation-summary-buffer": {
+        "impl": ConversationSummaryBuffer,
+        "default": {"max_token_limit": 40},
+    },
+}
+
+
+Memory = type(
+    "Memory",
+    (
+        ConversationBuffer,
+        ConversationBufferWindow,
+        ConversationSummary,
+        ConversationSummaryBuffer,
+    ),
+    {},
+)
+
+
+def memory(
+    name: str = "conversation-buffer-window",
+    memory_window: int | None = None,
+    max_token_limit: int | None = None,
+) -> Memory:
+    if name is None:
+        raise RuntimeError("Impossible to instantiate memory without a name.")
+
+    if name not in SUPPORTED_MEMORY:
+        raise ValueError(f"Memory {name} is not supported.")
+
+    cl = SUPPORTED_MEMORY[name]["impl"]
+
+    if name == "conversation-buffer-window":
+        if max_token_limit is not None:
+            raise ValueError(f"max_token_limit cannot be set for {name} memory")
+        return cl(memory_window=memory_window)
+
+    if name == "conversation-summary-buffer":
+        if memory_window is not None:
+            raise ValueError(f"memory_window cannot be set for {name} memory")
+        return cl(max_token_limit=max_token_limit)
+
+    return SUPPORTED_MEMORY[name]["impl"]()
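
The memory() factory can also be used on its own; the names come from SUPPORTED_MEMORY above. A short sketch:

from megabots.memory import memory

# Each call validates the name and the keyword arguments it allows.
window_mem = memory("conversation-buffer-window", memory_window=5)
summary_mem = memory("conversation-summary-buffer", max_token_limit=40)
print(window_mem.memory_window, summary_mem.max_token_limit)  # 5 40
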
megabots/utils.py
ADDED
@@ -0,0 +1,39 @@
+import gradio as gr
+from fastapi import FastAPI
+from megabots.bot import Bot
+
+
+def create_api(bot: Bot):
+    app = FastAPI()
+
+    @app.get("/v1/ask/{question}")
+    async def ask(question: str):
+        answer = bot.ask(question)
+        return {"answer": answer}
+
+    return app
+
+
+def create_interface(bot_instance: Bot, examples: list[list[str]] = []):
+    with gr.Blocks() as interface:
+        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
+        msg = gr.Textbox(
+            show_label=False,
+            placeholder="Enter text and press enter, or upload an image",
+        ).style(container=False)
+        clear = gr.Button("Clear")
+
+        def user(user_message, history):
+            return "", history + [[user_message, None]]
+
+        def bot(history):
+            response = bot_instance.ask(history[-1][0])
+            history[-1][1] = response
+            return history
+
+        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+            bot, chatbot, chatbot
+        )
+        clear.click(lambda: None, None, chatbot, queue=False)
+
+    return interface
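
create_api() and create_interface() wrap a Bot for serving. A sketch; serving with uvicorn and calling .launch() on the Gradio Blocks follow those libraries' standard usage and are assumptions, not something this commit shows:

from megabots import bot, create_api, create_interface

qnabot = bot("qna-over-docs", index="./index")
app = create_api(qnabot)       # serve with: uvicorn your_module:app
ui = create_interface(qnabot)  # ui.launch() starts the Gradio app
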
megabots/{vectorstores.py → vectorstore.py}
RENAMED
@@ -1,5 +1,5 @@
 from typing import Type, TypeVar
 from langchain.vectorstores import Milvus
 from abc import ABC
@@ -26,7 +26,9 @@ SUPPORTED_VECTORSTORES = {
 }
 
 
-def vectorstore(name: str) -> VectorStore:
+def vectorstore(
+    name: str, host: str | None = None, port: int | None = None
+) -> VectorStore:
     """Return a vectorstore object."""
 
     if name is None:
@@ -36,6 +38,6 @@ def vectorstore(name: str) -> VectorStore:
         raise ValueError(f"Vectorstore {name} is not supported.")
 
     return SUPPORTED_VECTORSTORES[name]["impl"](
-        host=SUPPORTED_VECTORSTORES[name]["default"]["host"],
-        port=SUPPORTED_VECTORSTORES[name]["default"]["port"],
+        host=host or SUPPORTED_VECTORSTORES[name]["default"]["host"],
+        port=port or SUPPORTED_VECTORSTORES[name]["default"]["port"],
     )
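
vectorstore() keeps its per-entry defaults but now lets callers override host and port. A sketch; the "milvus" key and port are assumptions, since the SUPPORTED_VECTORSTORES contents fall outside these hunks (the module does import Milvus):

from megabots.vectorstore import vectorstore

# Falls back to the entry's default host/port when the overrides are None.
store = vectorstore("milvus", host="localhost", port=19530)
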
tests/test_memory.py
ADDED
@@ -0,0 +1,42 @@
+import pytest
+from megabots.memory import (
+    ConversationBufferWindow,
+    ConversationSummaryBuffer,
+    memory,
+    Memory,
+    SUPPORTED_MEMORY,
+)
+
+
+def test_memory_name_none():
+    with pytest.raises(RuntimeError):
+        memory(name=None)
+
+
+def test_memory_not_supported():
+    with pytest.raises(ValueError):
+        memory(name="unsupported_memory_type")
+
+
+def test_memory_conversation_buffer_window():
+    mem_obj = memory(name="conversation-buffer-window", memory_window=5)
+    assert isinstance(mem_obj, ConversationBufferWindow)
+    assert mem_obj.memory_window == 5
+    assert mem_obj.__class__ == SUPPORTED_MEMORY["conversation-buffer-window"]["impl"]
+
+
+def test_memory_conversation_buffer_window_invalid_max_token_limit():
+    with pytest.raises(ValueError):
+        memory(name="conversation-buffer-window", memory_window=5, max_token_limit=10)
+
+
+def test_memory_conversation_summary_buffer():
+    mem_obj = memory(name="conversation-summary-buffer", max_token_limit=10)
+    assert isinstance(mem_obj, ConversationSummaryBuffer)
+    assert mem_obj.max_token_limit == 10
+    assert mem_obj.__class__ == SUPPORTED_MEMORY["conversation-summary-buffer"]["impl"]
+
+
+def test_memory_conversation_summary_buffer_invalid_memory_window():
+    with pytest.raises(ValueError):
+        memory(name="conversation-summary-buffer", memory_window=5, max_token_limit=10)
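
These tests exercise only the factory and its wrappers, which store the langchain memory classes without instantiating them, so no OpenAI key should be needed. A sketch of running them programmatically, equivalent to invoking pytest from the repo root:

import pytest

# Same as running: pytest tests/test_memory.py -v
pytest.main(["tests/test_memory.py", "-v"])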