Spaces:
Runtime error
Runtime error
import threading
from functools import wraps
from pathlib import Path

from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
def synchronized_mem(method):
    """Serialize access to a method through the instance's ``self.lock``.

    The wrapped method runs while holding ``self.lock``. Exceptions are
    caught and reported rather than propagated (best-effort semantics);
    in that case the wrapper returns ``None``.

    Args:
        method: an instance method of an object exposing a ``lock``
            attribute (a ``threading.Lock`` or compatible context manager).

    Returns:
        The wrapped method.
    """
    # wraps preserves __name__/__doc__ of the original method — the
    # error message below depends on __name__ being meaningful.
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                # Deliberate best-effort: report and return None.
                print(f"Failed to execute {method.__name__}: {e}")
    return wrapper
class VectorMemory:
    """Simple vector memory implementation using langchain and Chroma.

    Text entries are split into overlapping chunks, embedded with OpenAI
    embeddings, and persisted in a local Chroma collection. Retrieval is
    available either as raw similarity/MMR search or as a RetrievalQA chain.
    """

    def __init__(self, loc=None, chunk_size=1000, chunk_overlap_frac=0.1, *args, **kwargs):
        """Initialize the vector store and the QA retriever.

        Args:
            loc (str | Path | None): persistence directory for Chroma;
                defaults to ``./tmp/vector_memory``.
            chunk_size (int): maximum characters per chunk when splitting
                entries.
            chunk_overlap_frac (float): overlap between consecutive chunks,
                as a fraction of ``chunk_size``.
        """
        if loc is None:
            loc = "./tmp/vector_memory"
        self.loc = Path(loc)
        self.chunk_size = chunk_size
        # Splitters expect an integer number of characters, not a float.
        self.chunk_overlap = int(chunk_size * chunk_overlap_frac)
        self.embeddings = OpenAIEmbeddings()
        self.count = 0
        self.lock = threading.Lock()
        self.db = self._init_db()
        self.qa = self._init_retriever()

    def _init_db(self):
        """Create (or reopen) the persistent Chroma collection."""
        texts = ["init"]  # TODO find how to initialize Chroma without any text
        chroma_db = Chroma.from_texts(
            texts=texts,
            embedding=self.embeddings,
            persist_directory=str(self.loc),
        )
        self.count = chroma_db._collection.count()
        return chroma_db

    def _init_retriever(self):
        """Build a RetrievalQA chain over the Chroma store (MMR, k=10)."""
        model = ChatOpenAI(model='gpt-3.5-turbo', temperature=0)
        qa_chain = load_qa_chain(model, chain_type="stuff")
        retriever = self.db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
        qa = RetrievalQA(combine_documents_chain=qa_chain, retriever=retriever)
        return qa

    def add_entry(self, entry: str):
        """Split *entry* into chunks, embed them, and persist the store.

        Args:
            entry (str): the text to add to memory.

        Returns:
            bool: always True on completion.
        """
        text_splitter = CharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separator=" ",
        )
        texts = text_splitter.split_text(entry)
        self.db.add_texts(texts)
        # BUG FIX: the collection count is the authoritative total; the old
        # `+=` added the whole collection size on every call, inflating count.
        self.count = self.db._collection.count()
        self.db.persist()
        return True

    # NOTE: the `type` parameter name shadows the builtin but is kept for
    # backward compatibility with existing callers.
    def search_memory(self, query: str, k=10, type="mmr", distance_threshold=0.5):
        """Search the vector memory for entries similar to *query*.

        Args:
            query (str): the query to search for.
            k (int): the number of results to return.
            type (str): the type of search to perform: "cos" or "mmr".
            distance_threshold (float): for "cos" search, results with
                distance >= this threshold are dropped.

        Returns:
            list[str] | None: the top-k matching chunks, or None when the
            collection is too small to search.

        Raises:
            ValueError: if *type* is neither "mmr" nor "cos" (previously
            this fell through to an unbound-variable NameError).
        """
        self.count = self.db._collection.count()
        print(f"Searching {self.count} entries")
        # Never request more results than the collection holds.
        if k > self.count:
            k = self.count - 1
        if k <= 0:
            return None
        if type == "mmr":
            # fetch_k must be at least k, or MMR cannot select k results;
            # the old min(10, count) could be smaller than k.
            fetch_k = min(self.count, max(k, 10))
            docs = self.db.max_marginal_relevance_search(query=query, k=k, fetch_k=fetch_k)
            return [doc.page_content for doc in docs]
        if type == "cos":
            scored = self.db.similarity_search_with_score(query=query, k=k)
            return [doc.page_content for doc, score in scored if score < distance_threshold]
        raise ValueError(f"Unknown search type: {type!r} (expected 'mmr' or 'cos')")

    def ask_question(self, question: str):
        """Ask a question against the memory via the RetrievalQA chain.

        Args:
            question (str): the question to ask.

        Returns:
            str: the answer produced by the chain.
        """
        answer = self.qa.run(question)
        return answer