# NOTE(review): removed scraped Hugging Face Spaces page residue
# ("Spaces: / Runtime error / Runtime error") — it was never part of this module.
"""App bootstrap: environment loading, telemetry, and vector-store setup."""

# Standard library
import asyncio
import json
import logging
import os
import pickle

# Third-party
import chromadb
import logfire
from dotenv import load_dotenv
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

# Local
from custom_retriever import CustomRetriever
from utils import init_mongo_db

# Populate os.environ from a local .env file BEFORE any getenv call below.
load_dotenv()
# Initialise Logfire instrumentation (reads its config from the environment).
logfire.configure()
if not os.path.exists("data/chroma-db-all_sources"):
    # Download the prebuilt vector database from the Hugging Face Hub when it
    # is missing locally.
    # https://huggingface.co/datasets/towardsai-buster/ai-tutor-vector-db/tree/main
    logfire.warn(
        "Vector database does not exist at 'data/chroma-db-all_sources', downloading from Hugging Face Hub"
    )
    # Imported lazily so the dependency is only needed on first run.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="towardsai-tutors/ai-tutor-vector-db",
        local_dir="data",
        repo_type="dataset",
    )
    logfire.info("Downloaded vector database to 'data/chroma-db-all_sources'")
def create_docs(input_file: str) -> list[Document]:
    """Build one llama-index ``Document`` per line of a JSONL file.

    Args:
        input_file: Path to a JSONL file; each line must contain the keys
            ``doc_id``, ``content``, ``url``, ``name``, ``tokens``,
            ``retrieve_doc`` and ``source``.

    Returns:
        The parsed documents, with selected metadata keys hidden from the
        LLM prompt and from the embedding text respectively.

    Raises:
        KeyError: If a line is missing one of the expected keys.
        json.JSONDecodeError: If a line is not valid JSON.
    """
    documents: list[Document] = []
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            documents.append(
                Document(
                    doc_id=data["doc_id"],
                    text=data["content"],
                    metadata={  # type: ignore
                        "url": data["url"],
                        "title": data["name"],
                        "tokens": data["tokens"],
                        "retrieve_doc": data["retrieve_doc"],
                        "source": data["source"],
                    },
                    # The LLM sees the URL but not bookkeeping fields.
                    excluded_llm_metadata_keys=[
                        "title",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                    # Embeddings use the title but not the URL.
                    excluded_embed_metadata_keys=[
                        "url",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                )
            )
    return documents
def setup_database(db_collection: str, dict_file_name: str) -> CustomRetriever:
    """Build a ``CustomRetriever`` over a persistent Chroma collection.

    Args:
        db_collection: Chroma collection name; also the directory name under
            ``data/`` holding the persisted database and the pickle file.
        dict_file_name: Name of the pickle file (inside
            ``data/<db_collection>/``) mapping document ids to full documents.

    Returns:
        A ``CustomRetriever`` combining dense vector retrieval (top 15,
        Cohere ``embed-english-v3.0`` query embeddings) with the
        full-document lookup dict.

    Raises:
        KeyError: If ``COHERE_API_KEY`` is not set in the environment.
        FileNotFoundError: If the pickle file does not exist.
    """
    db = chromadb.PersistentClient(path=f"data/{db_collection}")
    chroma_collection = db.get_or_create_collection(db_collection)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    embed_model = CohereEmbedding(
        api_key=os.environ["COHERE_API_KEY"],
        model_name="embed-english-v3.0",
        input_type="search_query",
    )
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
        show_progress=True,
        # use_async=True,
    )
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=embed_model,
        # use_async=True,
    )
    # NOTE(security): pickle.load executes arbitrary code on load; only read
    # files you produced or trust (this one ships with the downloaded dataset).
    with open(f"data/{db_collection}/{dict_file_name}", "rb") as f:
        document_dict = pickle.load(f)
    return CustomRetriever(vector_retriever, document_dict)
# Module-level singleton retriever over the combined "all sources" collection;
# built once at import time (requires the Chroma DB downloaded above).
custom_retriever_all_sources: CustomRetriever = setup_database(
    "chroma-db-all_sources",
    "document_dict_all_sources.pkl",
)
# Maximum number of concurrent requests the app will serve (env-overridable).
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
# Optional MongoDB connection string; persistence is disabled when unset.
MONGODB_URI = os.getenv("MONGODB_URI")

# Human-readable source names shown in the UI; parallel to AVAILABLE_SOURCES.
AVAILABLE_SOURCES_UI = [
    "Transformers Docs",
    "PEFT Docs",
    "TRL Docs",
    "LlamaIndex Docs",
    "LangChain Docs",
    "OpenAI Cookbooks",
    "Towards AI Blog",
    "8 Hour Primer",
    "Advanced LLM Developer",
    "Python Primer",
]
# Internal source identifiers, index-aligned with AVAILABLE_SOURCES_UI.
AVAILABLE_SOURCES = [
    "transformers",
    "peft",
    "trl",
    "llama_index",
    "langchain",
    "openai_cookbooks",
    "tai_blog",
    "8-hour_primer",
    "llm_developer",
    "python_primer",
]

# mongo_db is None when no URI is configured (the original code relied on
# logfire.warn() returning None); downstream code must handle that case.
if MONGODB_URI:
    mongo_db = init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
else:
    logfire.warn("No mongodb uri found, you will not be able to save data.")
    mongo_db = None
# Explicit public API of this module.
__all__ = [
    "custom_retriever_all_sources",
    "mongo_db",
    "CONCURRENCY_COUNT",
    "AVAILABLE_SOURCES_UI",
    "AVAILABLE_SOURCES",
]