"""Shared setup for the AI tutor app: vector-store retriever, MongoDB handle and source lists.

Importing this module has side effects: it loads ``.env``, configures logfire,
downloads the vector database from the Hugging Face Hub when
``data/chroma-db-all_sources`` is missing, builds
``custom_retriever_all_sources`` and, when ``MONGODB_URI`` is set, initialises
``mongo_db``.
"""

import asyncio
import json
import logging
import os
import pickle

import chromadb
import logfire
from dotenv import load_dotenv
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore

from custom_retriever import CustomRetriever
from utils import init_mongo_db

load_dotenv()

logfire.configure()


if not os.path.exists("data/chroma-db-all_sources"):
    # Download the vector database from the Hugging Face Hub if it doesn't exist locally.
    # https://huggingface.co/datasets/towardsai-tutors/ai-tutor-vector-db/tree/main
    # Plain strings (not f-strings): logfire treats the first argument as a
    # message template.
    logfire.warn(
        "Vector database does not exist at 'data/chroma-db-all_sources', "
        "downloading from Hugging Face Hub"
    )
    # Imported here so huggingface_hub is only loaded when a download is needed.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="towardsai-tutors/ai-tutor-vector-db",
        local_dir="data",
        repo_type="dataset",
    )
    logfire.info("Downloaded vector database to 'data/chroma-db-all_sources'")


def create_docs(input_file: str) -> list[Document]:
    """Build LlamaIndex documents from a JSON-Lines file.

    Each non-blank line must be a JSON object with the keys ``doc_id``,
    ``content``, ``url``, ``name``, ``tokens``, ``retrieve_doc`` and
    ``source``. ``name`` is stored in the document metadata as ``title``.

    Args:
        input_file: Path to the ``.jsonl`` file; it is read as UTF-8.

    Returns:
        One ``Document`` per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    documents: list[Document] = []
    # Explicit UTF-8: the default text encoding is locale-dependent (e.g.
    # cp1252 on Windows) and would mis-decode non-ASCII content.
    with open(input_file, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate empty lines (e.g. an extra blank line at EOF)
                # instead of failing inside json.loads.
                continue
            data = json.loads(line)
            documents.append(
                Document(
                    doc_id=data["doc_id"],
                    text=data["content"],
                    metadata={  # type: ignore
                        "url": data["url"],
                        "title": data["name"],
                        "tokens": data["tokens"],
                        "retrieve_doc": data["retrieve_doc"],
                        "source": data["source"],
                    },
                    # Of the metadata, only "url" is shown to the LLM ...
                    excluded_llm_metadata_keys=[
                        "title",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                    # ... and only "title" is included in the embedded text.
                    excluded_embed_metadata_keys=[
                        "url",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                )
            )
    return documents


def setup_database(db_collection: str, dict_file_name: str) -> CustomRetriever:
    """Open a persisted Chroma collection and wrap it in a ``CustomRetriever``.

    Args:
        db_collection: Name of the Chroma collection; also the name of its
            persistence directory under ``data/``.
        dict_file_name: File name, inside ``data/<db_collection>/``, of the
            pickled document dictionary passed to ``CustomRetriever``.

    Returns:
        A ``CustomRetriever`` wrapping a ``VectorIndexRetriever`` that returns
        the 15 most similar nodes, with queries embedded by Cohere
        ``embed-english-v3.0``.

    Raises:
        KeyError: If ``COHERE_API_KEY`` is not set in the environment.
        FileNotFoundError: If the pickled document dictionary is missing.
    """
    db_path = f"data/{db_collection}"
    db = chromadb.PersistentClient(path=db_path)
    chroma_collection = db.get_or_create_collection(db_collection)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    # Query-side embedding model (input_type="search_query").
    embed_model = CohereEmbedding(
        api_key=os.environ["COHERE_API_KEY"],
        model_name="embed-english-v3.0",
        input_type="search_query",
    )
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
        show_progress=True,
    )
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=embed_model,
    )

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file. This file lives in the local data directory, which may have been
    # populated by the Hugging Face snapshot_download at import time — only
    # load pickles from sources you trust.
    with open(f"{db_path}/{dict_file_name}", "rb") as f:
        document_dict = pickle.load(f)

    return CustomRetriever(vector_retriever, document_dict)


custom_retriever_all_sources: CustomRetriever = setup_database(
    "chroma-db-all_sources",
    "document_dict_all_sources.pkl",
)

# Read from the CONCURRENCY_COUNT environment variable; defaults to 64.
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
MONGODB_URI = os.getenv("MONGODB_URI")

# Human-readable source labels. Item i labels item i of AVAILABLE_SOURCES,
# so keep both lists in the same order.
AVAILABLE_SOURCES_UI = [
    "Transformers Docs",
    "PEFT Docs",
    "TRL Docs",
    "LlamaIndex Docs",
    "LangChain Docs",
    "OpenAI Cookbooks",
    "Towards AI Blog",
    "8 Hour Primer",
    "Advanced LLM Developer",
    "Python Primer",
]

# Source identifiers, in the same order as AVAILABLE_SOURCES_UI.
AVAILABLE_SOURCES = [
    "transformers",
    "peft",
    "trl",
    "llama_index",
    "langchain",
    "openai_cookbooks",
    "tai_blog",
    "8-hour_primer",
    "llm_developer",
    "python_primer",
]

# Database handle from utils.init_mongo_db, or None when MONGODB_URI is unset
# (in which case a warning is logged and data cannot be saved).
if MONGODB_URI:
    mongo_db = init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
else:
    logfire.warn("No mongodb uri found, you will not be able to save data.")
    mongo_db = None

__all__ = [
    "custom_retriever_all_sources",
    "mongo_db",
    "CONCURRENCY_COUNT",
    "AVAILABLE_SOURCES_UI",
    "AVAILABLE_SOURCES",
]