momegas commited on
Commit
a091159
·
1 Parent(s): 44e486f

Some refactoring the memory inerface

Browse files
megabots/__init__.py CHANGED
@@ -1,256 +1,9 @@
1
- from typing import Any
2
- from langchain.llms import OpenAI
3
- from langchain.chat_models import ChatOpenAI
4
- from langchain.embeddings import OpenAIEmbeddings
5
- from langchain.chains.qa_with_sources import load_qa_with_sources_chain
6
- from langchain.vectorstores.faiss import FAISS
7
- import gradio as gr
8
- from fastapi import FastAPI
9
- import pickle
10
- import os
11
- from dotenv import load_dotenv
12
- from langchain.prompts import PromptTemplate
13
- from langchain.chains.question_answering import load_qa_chain
14
- from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
15
- from langchain.document_loaders import DirectoryLoader
16
- from megabots.vectorstores import VectorStore, vectorstore
17
- from langchain.memory import (
18
- ConversationBufferMemory,
19
- ConversationBufferWindowMemory,
20
- ConversationSummaryMemory,
21
- ConversationSummaryBufferMemory,
22
- )
23
-
24
- load_dotenv()
25
-
26
-
27
- class Bot:
28
- def __init__(
29
- self,
30
- model: str | None = None,
31
- prompt_template: str | None = None,
32
- prompt_variables: list[str] | None = None,
33
- index: str | None = None,
34
- sources: bool | None = False,
35
- vectorstore: VectorStore | None = None,
36
- memory: str | None = None,
37
- memory_window: int = 3,
38
- verbose: bool = False,
39
- temperature: int = 0,
40
- ):
41
- self.select_model(model, temperature)
42
- self.create_loader(index)
43
- self.load_or_create_index(index, vectorstore)
44
-
45
- # Load the question-answering chain for the selected model
46
- self.chain = self.create_chain(
47
- prompt_template, prompt_variables, sources=sources, verbose=verbose
48
- )
49
-
50
-
51
-
52
- def create_chain(
53
- self,
54
- prompt_template: str | None = None,
55
- prompt_variables: list[str] | None = None,
56
- sources: bool | None = False,
57
- verbose: bool = False,
58
- ):
59
- prompt = (
60
- PromptTemplate(template=prompt_template, input_variables=prompt_variables)
61
- if prompt_template is not None and prompt_variables is not None
62
- else QA_PROMPT
63
- )
64
- # TODO: Changing the prompt here is not working. Leave it as is for now.
65
- # Reference: https://github.com/hwchase17/langchain/issues/2858
66
- if sources:
67
- return load_qa_with_sources_chain(
68
- self.llm, chain_type="stuff", verbose=verbose
69
- )
70
- return load_qa_chain(
71
- self.llm, chain_type="stuff", verbose=verbose, prompt=prompt
72
- )
73
-
74
- def select_model(self, model: str | None, temperature: float):
75
- # Select and set the appropriate model based on the provided input
76
- if model is None or model == "gpt-3.5-turbo":
77
- print("Using model: gpt-3.5-turbo")
78
- self.llm = ChatOpenAI(temperature=temperature)
79
-
80
- if model == "text-davinci-003":
81
- print("Using model: text-davinci-003")
82
- self.llm = OpenAI(temperature=temperature)
83
-
84
- def create_loader(self, index: str | None):
85
- # Create a loader based on the provided directory (either local or S3)
86
- if index is None:
87
- raise RuntimeError(
88
- """
89
- Impossible to find a valid index.
90
- Either provide a valid path to a pickle file or a directory.
91
- """
92
- )
93
- self.loader = DirectoryLoader(index, recursive=True)
94
-
95
- def load_or_create_index(self, index: str, vectorstore: VectorStore | None = None):
96
- # Load an existing index from disk or create a new one if not available
97
- if vectorstore is not None:
98
- self.search_index = vectorstore.client.from_documents(
99
- self.loader.load_and_split(),
100
- OpenAIEmbeddings(),
101
- connection_args={"host": vectorstore.host, "port": vectorstore.port},
102
- )
103
- return
104
-
105
- # Is pickle
106
- if index is not None and "pkl" in index or "pickle" in index:
107
- print("Loading path from pickle file: ", index, "...")
108
- with open(index, "rb") as f:
109
- self.search_index = pickle.load(f)
110
- return
111
-
112
- # Is directory
113
- if index is not None and os.path.isdir(index):
114
- print("Creating index...")
115
- self.search_index = FAISS.from_documents(
116
- self.loader.load_and_split(), OpenAIEmbeddings()
117
- )
118
- return
119
-
120
- raise RuntimeError(
121
- """
122
- Impossible to find a valid index.
123
- Either provide a valid path to a pickle file or a directory.
124
- """
125
- )
126
-
127
- def save_index(self, index_path: str):
128
- # Save the index to the specified path
129
- with open(index_path, "wb") as f:
130
- pickle.dump(self.search_index, f)
131
-
132
- def ask(self, question: str, k=1) -> str:
133
- # Retrieve the answer to the given question and return it
134
- input_documents = self.search_index.similarity_search(question, k=k)
135
- answer = self.chain.run(input_documents=input_documents, question=question)
136
- return answer
137
-
138
-
139
- SUPPORTED_TASKS = {
140
- "qna-over-docs": {
141
- "impl": Bot,
142
- "default": {
143
- "model": "gpt-3.5-turbo",
144
- "temperature": 0,
145
- "index": "./index",
146
- },
147
- }
148
- }
149
 
150
- SUPPORTED_MODELS = {}
151
 
152
- SUPPORTED_MEMORY = {
153
- "conversation-buffer-window": {
154
- "impl": ConversationBufferWindowMemory,
155
- "default": {"memory_window": 3},
156
- },
157
- "conversation-buffer": {
158
- "impl": ConversationBufferMemory,
159
- "default": {},
160
- },
161
- "conversation-summary": {
162
- "impl": ConversationSummaryMemory,
163
- "default": {},
164
- "conversation-summary-buffer": {
165
- "impl": ConversationSummaryBufferMemory,
166
- "default": {
167
- "max_token_limit":40
168
- }
169
- },
170
- }
171
-
172
-
173
- def bot(
174
- task: str | None = None,
175
- model: str | None = None,
176
- index: str | None = None,
177
- prompt_template: str | None = None,
178
- prompt_variables: list[str] | None = None,
179
- memory: str | None = None,
180
- memory_window: int = 3,
181
- verbose: bool = False,
182
- temperature: int = 0,
183
- **kwargs,
184
- ) -> Bot:
185
- """Instanciate a bot based on the provided task. Each supported tasks has it's own default sane defaults.
186
-
187
- Args:
188
- task (str | None, optional): The given task. Can be one of the SUPPORTED_TASKS.
189
- model (str | None, optional): Model to be used. Can be one of the SUPPORTED_MODELS.
190
- index (str | None, optional): Data that the model will load and store index info.
191
- Can be either a local file path, a pickle file, or a url of a vector database.
192
- By default it will look for a local directory called "files" in the current working directory.
193
- prompt_template (str | None, optional): Prompt template to be used. Specify variables with {var_name}.
194
- prompt_variables (list[str] | None, optional): Prompt variables to be used in the prompt template.
195
- verbose (bool, optional): Verbocity. Defaults to False.
196
- temperature (int, optional): Temperature. Defaults to 0.
197
-
198
- Raises:
199
- RuntimeError: _description_
200
- ValueError: _description_
201
-
202
- Returns:
203
- Bot: Bot instance
204
- """
205
-
206
- if task is None:
207
- raise RuntimeError("Impossible to instantiate a bot without a task.")
208
- if task not in SUPPORTED_TASKS:
209
- raise ValueError(f"Task {task} is not supported.")
210
-
211
- task_defaults = SUPPORTED_TASKS[task]["default"]
212
- return SUPPORTED_TASKS[task]["impl"](
213
- model=model or task_defaults["model"],
214
- index=index or task_defaults["index"],
215
- prompt_template=prompt_template,
216
- prompt_variables=prompt_variables,
217
- temperature=temperature,
218
- verbose=verbose,
219
- **kwargs,
220
- )
221
-
222
-
223
- def create_api(bot: Bot):
224
- app = FastAPI()
225
-
226
- @app.get("/v1/ask/{question}")
227
- async def ask(question: str):
228
- answer = bot.ask(question)
229
- return {"answer": answer}
230
-
231
- return app
232
-
233
-
234
- def create_interface(bot_instance: Bot, examples: list[list[str]] = []):
235
- with gr.Blocks() as interface:
236
- chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
237
- msg = gr.Textbox(
238
- show_label=False,
239
- placeholder="Enter text and press enter, or upload an image",
240
- ).style(container=False)
241
- clear = gr.Button("Clear")
242
-
243
- def user(user_message, history):
244
- return "", history + [[user_message, None]]
245
-
246
- def bot(history):
247
- response = bot_instance.ask(history[-1][0])
248
- history[-1][1] = response
249
- return history
250
-
251
- msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
252
- bot, chatbot, chatbot
253
- )
254
- clear.click(lambda: None, None, chatbot, queue=False)
255
 
256
- return interface
 
1
+ from megabots.vectorstore import VectorStore, vectorstore
2
+ from megabots.memory import Memory, memory
3
+ from megabots.bot import Bot, bot
4
+ from megabots.utils import create_api, create_interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
6
 
7
+ from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ load_dotenv()
megabots/bot.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ from langchain.llms import OpenAI
3
+ from langchain.chat_models import ChatOpenAI
4
+ from langchain.embeddings import OpenAIEmbeddings
5
+ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
6
+ from langchain.vectorstores.faiss import FAISS
7
+ import pickle
8
+ import os
9
+ from langchain.prompts import PromptTemplate
10
+ from langchain.chains.question_answering import load_qa_chain
11
+ from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
12
+ from langchain.document_loaders import DirectoryLoader
13
+ from megabots.vectorstore import VectorStore
14
+ from megabots.memory import Memory
15
+ import megabots
16
+
17
+
18
+ class Bot:
19
+ def __init__(
20
+ self,
21
+ model: str | None = None,
22
+ prompt_template: str | None = None,
23
+ prompt_variables: list[str] | None = None,
24
+ index: str | None = None,
25
+ sources: bool | None = False,
26
+ vectorstore: VectorStore | None = None,
27
+ memory: Memory | None = None,
28
+ verbose: bool = False,
29
+ temperature: int = 0,
30
+ ):
31
+ self.select_model(model, temperature)
32
+ self.create_loader(index)
33
+ self.load_or_create_index(index, vectorstore)
34
+ self.vectorstore = vectorstore
35
+ self.memory = memory
36
+ # Load the question-answering chain for the selected model
37
+ self.chain = self.create_chain(
38
+ prompt_template, prompt_variables, sources=sources, verbose=verbose
39
+ )
40
+
41
+ def create_chain(
42
+ self,
43
+ prompt_template: str | None = None,
44
+ prompt_variables: list[str] | None = None,
45
+ sources: bool | None = False,
46
+ verbose: bool = False,
47
+ ):
48
+ prompt = (
49
+ PromptTemplate(template=prompt_template, input_variables=prompt_variables)
50
+ if prompt_template is not None and prompt_variables is not None
51
+ else QA_PROMPT
52
+ )
53
+ # TODO: Changing the prompt here is not working. Leave it as is for now.
54
+ # Reference: https://github.com/hwchase17/langchain/issues/2858
55
+ if sources:
56
+ return load_qa_with_sources_chain(
57
+ self.llm, chain_type="stuff", verbose=verbose
58
+ )
59
+ return load_qa_chain(
60
+ self.llm, chain_type="stuff", verbose=verbose, prompt=prompt
61
+ )
62
+
63
+ def select_model(self, model: str | None, temperature: float):
64
+ # Select and set the appropriate model based on the provided input
65
+ if model is None or model == "gpt-3.5-turbo":
66
+ print("Using model: gpt-3.5-turbo")
67
+ self.llm = ChatOpenAI(temperature=temperature)
68
+
69
+ if model == "text-davinci-003":
70
+ print("Using model: text-davinci-003")
71
+ self.llm = OpenAI(temperature=temperature)
72
+
73
+ def create_loader(self, index: str | None):
74
+ # Create a loader based on the provided directory (either local or S3)
75
+ if index is None:
76
+ raise RuntimeError(
77
+ """
78
+ Impossible to find a valid index.
79
+ Either provide a valid path to a pickle file or a directory.
80
+ """
81
+ )
82
+ self.loader = DirectoryLoader(index, recursive=True)
83
+
84
+ def load_or_create_index(self, index: str, vectorstore: VectorStore | None = None):
85
+ # Load an existing index from disk or create a new one if not available
86
+ if vectorstore is not None:
87
+ self.search_index = vectorstore.client.from_documents(
88
+ self.loader.load_and_split(),
89
+ OpenAIEmbeddings(),
90
+ connection_args={"host": vectorstore.host, "port": vectorstore.port},
91
+ )
92
+ return
93
+
94
+ # Is pickle
95
+ if index is not None and "pkl" in index or "pickle" in index:
96
+ print("Loading path from pickle file: ", index, "...")
97
+ with open(index, "rb") as f:
98
+ self.search_index = pickle.load(f)
99
+ return
100
+
101
+ # Is directory
102
+ if index is not None and os.path.isdir(index):
103
+ print("Creating index...")
104
+ self.search_index = FAISS.from_documents(
105
+ self.loader.load_and_split(), OpenAIEmbeddings()
106
+ )
107
+ return
108
+
109
+ raise RuntimeError(
110
+ """
111
+ Impossible to find a valid index.
112
+ Either provide a valid path to a pickle file or a directory.
113
+ """
114
+ )
115
+
116
+ def save_index(self, index_path: str):
117
+ # Save the index to the specified path
118
+ with open(index_path, "wb") as f:
119
+ pickle.dump(self.search_index, f)
120
+
121
+ def ask(self, question: str, k=1) -> str:
122
+ # Retrieve the answer to the given question and return it
123
+ input_documents = self.search_index.similarity_search(question, k=k)
124
+ answer = self.chain.run(input_documents=input_documents, question=question)
125
+ return answer
126
+
127
+
128
+ SUPPORTED_TASKS = {
129
+ "qna-over-docs": {
130
+ "impl": Bot,
131
+ "default": {
132
+ "model": "gpt-3.5-turbo",
133
+ "temperature": 0,
134
+ "index": "./index",
135
+ },
136
+ }
137
+ }
138
+
139
+ SUPPORTED_MODELS = {}
140
+
141
+
142
+ def bot(
143
+ task: str | None = None,
144
+ model: str | None = None,
145
+ index: str | None = None,
146
+ prompt_template: str | None = None,
147
+ prompt_variables: list[str] | None = None,
148
+ memory: str | Memory | None = None,
149
+ vectorstore: str | VectorStore | None = None,
150
+ verbose: bool = False,
151
+ temperature: int = 0,
152
+ ) -> Bot:
153
+ """Instanciate a bot based on the provided task. Each supported tasks has it's own default sane defaults.
154
+
155
+ Args:
156
+ task (str | None, optional): The given task. Can be one of the SUPPORTED_TASKS.
157
+ model (str | None, optional): Model to be used. Can be one of the SUPPORTED_MODELS.
158
+ index (str | None, optional): Data that the model will load and store index info.
159
+ Can be either a local file path, a pickle file, or a url of a vector database.
160
+ By default it will look for a local directory called "files" in the current working directory.
161
+ prompt_template (str | None, optional): Prompt template to be used. Specify variables with {var_name}.
162
+ prompt_variables (list[str] | None, optional): Prompt variables to be used in the prompt template.
163
+ verbose (bool, optional): Verbocity. Defaults to False.
164
+ temperature (int, optional): Temperature. Defaults to 0.
165
+
166
+ Raises:
167
+ RuntimeError: _description_
168
+ ValueError: _description_
169
+
170
+ Returns:
171
+ Bot: Bot instance
172
+ """
173
+
174
+ if task is None:
175
+ raise RuntimeError("Impossible to instantiate a bot without a task.")
176
+ if task not in SUPPORTED_TASKS:
177
+ raise ValueError(f"Task {task} is not supported.")
178
+
179
+ task_defaults = SUPPORTED_TASKS[task]["default"]
180
+
181
+ return SUPPORTED_TASKS[task]["impl"](
182
+ model=model or task_defaults["model"],
183
+ index=index or task_defaults["index"],
184
+ prompt_template=prompt_template,
185
+ prompt_variables=prompt_variables,
186
+ temperature=temperature,
187
+ verbose=verbose,
188
+ vectorstore=megabots.vectorstore(vectorstore)
189
+ if isinstance(vectorstore, str)
190
+ else vectorstore,
191
+ memory=megabots.memory(memory) if isinstance(memory, str) else memory,
192
+ )
megabots/memory.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.memory import (
2
+ ConversationBufferMemory,
3
+ ConversationBufferWindowMemory,
4
+ ConversationSummaryMemory,
5
+ ConversationSummaryBufferMemory,
6
+ )
7
+
8
+
9
+ class ConversationBuffer:
10
+ def __init__(self):
11
+ self.memory = ConversationBufferMemory
12
+
13
+
14
+ class ConversationBufferWindow:
15
+ def __init__(self, memory_window: int):
16
+ self.memory_window: int = memory_window
17
+ self.memory = ConversationBufferWindowMemory
18
+
19
+
20
+ class ConversationSummary:
21
+ def __init__(self):
22
+ self.memory = ConversationSummaryMemory
23
+
24
+
25
+ class ConversationSummaryBuffer:
26
+ def __init__(self, max_token_limit: int):
27
+ self.max_token_limit: int = max_token_limit
28
+ self.memory = ConversationSummaryBufferMemory
29
+
30
+
31
+ SUPPORTED_MEMORY = {
32
+ "conversation-buffer": {
33
+ "impl": ConversationBuffer,
34
+ "default": {},
35
+ },
36
+ "conversation-buffer-window": {
37
+ "impl": ConversationBufferWindow,
38
+ "default": {"memory_window": 3},
39
+ },
40
+ "conversation-summary": {
41
+ "impl": ConversationSummary,
42
+ "default": {},
43
+ },
44
+ "conversation-summary-buffer": {
45
+ "impl": ConversationSummaryBuffer,
46
+ "default": {"max_token_limit": 40},
47
+ },
48
+ }
49
+
50
+
51
+ Memory = type(
52
+ "Memory",
53
+ (
54
+ ConversationBuffer,
55
+ ConversationBufferWindow,
56
+ ConversationSummary,
57
+ ConversationSummaryBuffer,
58
+ ),
59
+ {},
60
+ )
61
+
62
+
63
+ def memory(
64
+ name: str = "conversation-buffer-window",
65
+ memory_window: int | None = None,
66
+ max_token_limit: int | None = None,
67
+ ) -> Memory:
68
+ if name is None:
69
+ raise RuntimeError("Impossible to instantiate memory without a name.")
70
+
71
+ if name not in SUPPORTED_MEMORY:
72
+ raise ValueError(f"Memory {name} is not supported.")
73
+
74
+ cl = SUPPORTED_MEMORY[name]["impl"]
75
+
76
+ if name == "conversation-buffer-window":
77
+ if max_token_limit != None:
78
+ raise ValueError(f"max_token_limit cannot be set for {name} memory")
79
+ return cl(memory_window=memory_window)
80
+
81
+ if name == "conversation-summary-buffer":
82
+ if max_token_limit != None:
83
+ raise ValueError(f"memory_window cannot be set for {name} memory")
84
+ return cl(max_token_limit=max_token_limit)
85
+
86
+ return SUPPORTED_MEMORY[name]["impl"]()
megabots/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from fastapi import FastAPI
3
+ from megabots.bot import Bot
4
+
5
+
6
+ def create_api(bot: Bot):
7
+ app = FastAPI()
8
+
9
+ @app.get("/v1/ask/{question}")
10
+ async def ask(question: str):
11
+ answer = bot.ask(question)
12
+ return {"answer": answer}
13
+
14
+ return app
15
+
16
+
17
+ def create_interface(bot_instance: Bot, examples: list[list[str]] = []):
18
+ with gr.Blocks() as interface:
19
+ chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
20
+ msg = gr.Textbox(
21
+ show_label=False,
22
+ placeholder="Enter text and press enter, or upload an image",
23
+ ).style(container=False)
24
+ clear = gr.Button("Clear")
25
+
26
+ def user(user_message, history):
27
+ return "", history + [[user_message, None]]
28
+
29
+ def bot(history):
30
+ response = bot_instance.ask(history[-1][0])
31
+ history[-1][1] = response
32
+ return history
33
+
34
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
35
+ bot, chatbot, chatbot
36
+ )
37
+ clear.click(lambda: None, None, chatbot, queue=False)
38
+
39
+ return interface
megabots/{vectorstores.py → vectorstore.py} RENAMED
@@ -1,5 +1,5 @@
1
  from typing import Type, TypeVar
2
- from langchain.vectorstores import Milvus, Qdrant
3
  from abc import ABC
4
 
5
 
@@ -26,7 +26,9 @@ SUPPORTED_VECTORSTORES = {
26
  }
27
 
28
 
29
- def vectorstore(name: str) -> VectorStore:
 
 
30
  """Return a vectorstore object."""
31
 
32
  if name is None:
@@ -36,6 +38,6 @@ def vectorstore(name: str) -> VectorStore:
36
  raise ValueError(f"Vectorstore {name} is not supported.")
37
 
38
  return SUPPORTED_VECTORSTORES[name]["impl"](
39
- host=SUPPORTED_VECTORSTORES[name]["default"]["host"],
40
- port=SUPPORTED_VECTORSTORES[name]["default"]["port"],
41
  )
 
1
  from typing import Type, TypeVar
2
+ from langchain.vectorstores import Milvus
3
  from abc import ABC
4
 
5
 
 
26
  }
27
 
28
 
29
+ def vectorstore(
30
+ name: str, host: str | None = None, port: int | None = None
31
+ ) -> VectorStore:
32
  """Return a vectorstore object."""
33
 
34
  if name is None:
 
38
  raise ValueError(f"Vectorstore {name} is not supported.")
39
 
40
  return SUPPORTED_VECTORSTORES[name]["impl"](
41
+ host=host or SUPPORTED_VECTORSTORES[name]["default"]["host"],
42
+ port=port or SUPPORTED_VECTORSTORES[name]["default"]["port"],
43
  )
tests/test_memory.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from megabots.memory import (
3
+ ConversationBufferWindow,
4
+ ConversationSummaryBuffer,
5
+ memory,
6
+ Memory,
7
+ SUPPORTED_MEMORY,
8
+ )
9
+
10
+
11
+ def test_memory_name_none():
12
+ with pytest.raises(RuntimeError):
13
+ memory(name=None)
14
+
15
+
16
+ def test_memory_not_supported():
17
+ with pytest.raises(ValueError):
18
+ memory(name="unsupported_memory_type")
19
+
20
+
21
+ def test_memory_conversation_buffer_window():
22
+ mem_obj = memory(name="conversation-buffer-window", memory_window=5)
23
+ assert isinstance(mem_obj, ConversationBufferWindow)
24
+ assert mem_obj.memory_window == 5
25
+ assert mem_obj.__class__ == SUPPORTED_MEMORY["conversation-buffer-window"]["impl"]
26
+
27
+
28
+ def test_memory_conversation_buffer_window_invalid_max_token_limit():
29
+ with pytest.raises(ValueError):
30
+ memory(name="conversation-buffer-window", memory_window=5, max_token_limit=10)
31
+
32
+
33
+ def test_memory_conversation_summary_buffer():
34
+ mem_obj = memory(name="conversation-summary-buffer", max_token_limit=10)
35
+ assert isinstance(mem_obj, ConversationSummaryBuffer)
36
+ assert mem_obj.max_token_limit == 10
37
+ assert mem_obj.__class__ == SUPPORTED_MEMORY["conversation-summary-buffer"]["impl"]
38
+
39
+
40
+ def test_memory_conversation_summary_buffer_invalid_memory_window():
41
+ with pytest.raises(ValueError):
42
+ memory(name="conversation-summary-buffer", memory_window=5, max_token_limit=10)