# src/utils/llm_utils.py
import asyncio
import gc

from fastapi import HTTPException
from typing import Tuple

from src.llms.openai_llm import OpenAILanguageModel
from src.llms.ollama_llm import OllamaLanguageModel
from src.llms.bert_llm import BERTLanguageModel
from src.llms.falcon_llm import FalconLanguageModel
from src.llms.llama_llm import LlamaLanguageModel
from src.embeddings.huggingface_embedding import HuggingFaceEmbedding
from src.vectorstores.chroma_vectorstore import ChromaVectorStore
from src.vectorstores.chroma_manager import ChromaManager
from src.utils.logger import logger
from config.config import settings

# Global vector store and embedding model instances for reuse
_vector_store = None
_embedding_model = None
_vs_lock = asyncio.Lock()


def get_llm_instance(provider: str):
    """
    Get an LLM instance for the given provider.

    Args:
        provider (str): Name of the LLM provider.

    Returns:
        BaseLLM: Instance of the LLM.

    Raises:
        ValueError: If the provider is not supported.
    """
    # Lazy constructors so only the requested model is instantiated
    llm_map = {
        'openai': lambda: OpenAILanguageModel(api_key=settings.OPENAI_API_KEY),
        'ollama': lambda: OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL),
        'bert': lambda: BERTLanguageModel(),
        'falcon': lambda: FalconLanguageModel(),
        'llama': lambda: LlamaLanguageModel(),
    }
    if provider not in llm_map:
        raise ValueError(f"Unsupported LLM provider: {provider}")
    return llm_map[provider]()
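

# Usage sketch (illustrative; `generate` is an assumed method name on the LLM
# classes, not confirmed by this module):
#
#     llm = get_llm_instance("ollama")
#     answer = llm.generate("Summarize the uploaded document.")
#
# An unknown provider raises ValueError, which a route handler can map to an
# HTTP 400 response.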


async def get_vector_store() -> Tuple[ChromaVectorStore, HuggingFaceEmbedding]:
    """
    Get vector store and embedding model instances, initializing them on
    first use.

    Returns:
        Tuple[ChromaVectorStore, HuggingFaceEmbedding]:
            Vector store and embedding model instances.
    """
    global _vector_store, _embedding_model
    async with _vs_lock:
        # Reuse cached instances if both are already initialized
        if _vector_store is not None and _embedding_model is not None:
            return _vector_store, _embedding_model
        try:
            # Load the embedding model
            _embedding_model = HuggingFaceEmbedding(
                model_name=settings.EMBEDDING_MODEL)
            logger.info(f"Loaded embedding model: {settings.EMBEDDING_MODEL}")

            # Get a ChromaDB client through the manager
            try:
                client = await ChromaManager.get_client(
                    persist_directory=settings.CHROMA_PATH,
                    reset_if_needed=True
                )
                logger.info("Successfully initialized ChromaDB client")
            except Exception as e:
                logger.error(f"Error getting ChromaDB client: {str(e)}")
                # Fall back to a full ChromaDB reset, then retry once
                await ChromaManager.reset_chroma(settings.CHROMA_PATH)
                client = await ChromaManager.get_client(
                    persist_directory=settings.CHROMA_PATH
                )
                logger.info("Recreated ChromaDB client after reset")

            # Create and initialize the vector store
            _vector_store = ChromaVectorStore(
                embedding_function=_embedding_model.embed_documents,
                persist_directory=settings.CHROMA_PATH,
                collection_name="documents",
                client=client
            )
            await _vector_store.initialize()
            logger.info("Vector store successfully initialized")
            return _vector_store, _embedding_model
        except Exception as e:
            # Reset partial state so the next call retries a clean initialization
            _vector_store = None
            _embedding_model = None
            logger.error(f"Error initializing vector store: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to initialize vector store: {str(e)}"
            )
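

# Usage sketch (illustrative; `embed_query` and `similarity_search` are assumed
# method names on the embedding model and vector store, not confirmed here):
#
#     vector_store, embedding_model = await get_vector_store()
#     query_vector = embedding_model.embed_query(question)
#     matches = await vector_store.similarity_search(query_vector, k=4)
#
# Only the first call pays the model-load and ChromaDB startup cost; later
# calls return the cached instances under the same lock.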


async def cleanup_vectorstore():
    """
    Clean up and reset vector store resources.
    """
    global _vector_store, _embedding_model
    async with _vs_lock:
        _vector_store = None
        _embedding_model = None
        # Force garbage collection so model memory is released promptly
        gc.collect()
        # Reset ChromaDB completely
        await ChromaManager.reset_chroma(settings.CHROMA_PATH)
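

# Usage sketch (illustrative): this cleanup is a natural fit for an application
# shutdown hook, e.g. FastAPI's lifespan handler:
#
#     from contextlib import asynccontextmanager
#     from fastapi import FastAPI
#
#     @asynccontextmanager
#     async def lifespan(app: FastAPI):
#         yield                        # application serves requests here
#         await cleanup_vectorstore()  # release model and ChromaDB resources
#
#     app = FastAPI(lifespan=lifespan)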