#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/5/25 10:20
@Author  : alexanderwu
@File    : https://github.com/geekan/MetaGPT/blob/main/metagpt/document_store/faiss_store.py
"""
import pickle
from pathlib import Path
from typing import Optional

import faiss
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

from autoagents.system.const import DATA_PATH
from autoagents.system.document_store.base_store import LocalStore
from autoagents.system.document_store.document import Document
from autoagents.system.logs import logger


class FaissStore(LocalStore):
    """FAISS-backed local document store: builds a LangChain FAISS index from a raw
    data file and caches both the index and the pickled docstore on disk."""

    def __init__(self, raw_data: Path, cache_dir=None, meta_col='source', content_col='output'):
        self.meta_col = meta_col
        self.content_col = content_col
        super().__init__(raw_data, cache_dir)

    def _load(self) -> Optional["FaissStore"]:
        index_file, store_file = self._get_index_and_store_fname()
        if not (index_file.exists() and store_file.exists()):
            logger.info("Missing at least one of index_file/store_file, load failed and return None")
            return None
        index = faiss.read_index(str(index_file))
        with open(str(store_file), "rb") as f:
            store = pickle.load(f)
        # Re-attach the FAISS index, which is stored in a separate file from the pickled docstore.
        store.index = index
        return store

    def _write(self, docs, metadatas):
        store = FAISS.from_texts(docs, OpenAIEmbeddings(openai_api_version="2020-11-07"), metadatas=metadatas)
        return store

    def persist(self):
        index_file, store_file = self._get_index_and_store_fname()
        store = self.store
        index = self.store.index
        # The FAISS index is written with faiss.write_index; it is detached before pickling
        # the rest of the store (raw FAISS indexes don't pickle cleanly) and re-attached after.
        faiss.write_index(store.index, str(index_file))
        store.index = None
        with open(store_file, "wb") as f:
            pickle.dump(store, f)
        store.index = index

    def search(self, query, expand_cols=False, sep='\n', *args, k=5, **kwargs):
        rsp = self.store.similarity_search(query, k=k)
        logger.debug(rsp)
        if expand_cols:
            return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))
        else:
            return str(sep.join([f"{x.page_content}" for x in rsp]))

    def write(self):
        """Initialize the index and the store from the user-provided Document file (JSON / XLSX, etc.)."""
        if not self.raw_data.exists():
            raise FileNotFoundError
        doc = Document(self.raw_data, self.content_col, self.meta_col)
        docs, metadatas = doc.get_docs_and_metadatas()
        self.store = self._write(docs, metadatas)
        self.persist()
        return self.store

    def add(self, texts: list[str], *args, **kwargs) -> list[str]:
        """FIXME: the persisted store is currently not updated after add."""
        return self.store.add_texts(texts)
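
    # A minimal sketch (not part of the original class) of how the FIXME above could be
    # addressed, assuming LangChain's FAISS.add_texts() updates the in-memory index and
    # docstore in place and returns the new ids:
    #
    #     def add(self, texts: list[str], *args, **kwargs) -> list[str]:
    #         ids = self.store.add_texts(texts)
    #         self.persist()  # flush the updated index and docstore back to the cache files
    #         return ids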

    def delete(self, *args, **kwargs):
        """LangChain does not currently expose a delete interface for FAISS."""
        raise NotImplementedError


if __name__ == '__main__':
    faiss_store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
    # Query string '油皮洗面奶' ("facial cleanser for oily skin") matches the language of the sample dataset.
    logger.info(faiss_store.search('油皮洗面奶'))
    faiss_store.add([f'油皮洗面奶-{i}' for i in range(3)])
    logger.info(faiss_store.search('油皮洗面奶'))
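
# Usage sketch (assumptions: the qcs_4w.json sample exists under DATA_PATH and a previous
# run has already called persist()). On a later construction, LocalStore is expected to find
# the cached index/store files and return the result of _load() instead of re-embedding the
# raw data:
#
#     store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')  # loads from the cache if present
#     print(store.search('油皮洗面奶'))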