"""Project settings loader.

Reads secrets from .env and application configuration from config.yaml,
then configures llama_index global Settings with the selected LLM and
embedding backends (Ollama, OpenAI, or Groq).
"""
from typing import Optional, List

import yaml
from pydantic import Field, ValidationError
from pydantic_settings import BaseSettings
from llama_index.core import Settings
from llama_index.llms.openai import OpenAI
from llama_index.llms.ollama import Ollama
from llama_index.llms.groq import Groq
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
class ProjectSettings(BaseSettings):
    """Application settings.

    Loads API keys from the environment / .env, reads config.yaml, and
    configures llama_index's global ``Settings`` with the LLM + embedding
    backend named in ``userPreference.llmService``.

    Raises:
        ValueError: if config.yaml loads empty/None, or names an unknown
            LLM service.
        FileNotFoundError: if config.yaml is missing (propagated from open).
    """

    OPENAI_API_KEY: Optional[str] = Field(None, env='OPENAI_API_KEY')
    GROQ_KEY: Optional[str] = Field(None, env='GROQ_KEY')
    # Parsed contents of config.yaml; populated in __init__.
    config: Optional[dict] = None
    # Module names from userPreference.ModuleList; empty if absent.
    moduleList: List[str] = Field(default_factory=list)

    class Config:
        env_file = '.env'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.config = self._read_yaml_config()
        if not self.config:
            # BUG FIX: pydantic's ValidationError cannot be constructed from a
            # bare message string (it would raise TypeError here); a plain
            # ValueError correctly signals the unusable config.
            raise ValueError("Config file could not be loaded")
        self._instantiate_services()
        # BUG FIX: was .get('userPreference', {})['ModuleList'] — the bare
        # indexing raised KeyError when ModuleList was absent; default to [].
        self.moduleList = self.config.get('userPreference', {}).get('ModuleList', [])

    def _instantiate_services(self):
        """Dispatch to the initializer for the configured LLM service."""
        llm_service = self.config.get('userPreference', {}).get('llmService', '').lower()
        initializers = {
            'ollama': self._initialize_ollama,
            'openai': self._initialize_openai,
            'groq': self._initialize_groq,
        }
        if llm_service not in initializers:
            raise ValueError(f"Invalid LLM service: {llm_service}")
        initializers[llm_service]()

    def _initialize_ollama(self):
        """Point llama_index at a local Ollama server for both LLM and embeddings."""
        ollama_config = self.config.get('OllamaConfig', {})
        # NOTE(review): the same 'llm' model name is used for embeddings —
        # confirm the config intends one model for both roles.
        Settings.llm = Ollama(base_url=ollama_config['baseURL'], model=ollama_config['llm'])
        Settings.embed_model = OllamaEmbedding(base_url=ollama_config['baseURL'], model_name=ollama_config['llm'])

    def _initialize_openai(self):
        """Configure OpenAI-hosted LLM and embedding models."""
        openai_config = self.config.get('OpenAIConfig', {})
        Settings.llm = OpenAI(model=openai_config['llm'], api_key=self.OPENAI_API_KEY)
        Settings.embed_model = OpenAIEmbedding(model=openai_config['embeddingModel'], api_key=self.OPENAI_API_KEY)

    def _initialize_groq(self):
        """Configure a Groq-hosted LLM with a local HuggingFace embedding model."""
        groq_config = self.config.get('GroqConfig', {})
        Settings.llm = Groq(model=groq_config['llm'], api_key=self.GROQ_KEY)
        Settings.embed_model = HuggingFaceEmbedding(model_name=groq_config['hfEmbedding'])

    def _read_yaml_config(self):
        """Load and return config.yaml from the working directory.

        BUG FIX: the original definition omitted ``self`` but was invoked as
        ``self._read_yaml_config()``, which raised TypeError on every call.
        Returns whatever yaml.safe_load yields (dict, or None for empty file).
        """
        with open("config.yaml", "r") as file:
            return yaml.safe_load(file)
settings = ProjectSettings() | |