ciyidogan committed
Commit 8304bb2 · verified · 1 Parent(s): 5e9eccb

Upload 5 files

llm/llm_factory.py ADDED
@@ -0,0 +1,97 @@
+ """
+ LLM Provider Factory for Flare
+ """
+ import os
+ from typing import Optional
+ from dotenv import load_dotenv
+
+ from llm_interface import LLMInterface
+ from llm_spark import SparkLLM
+ from llm_openai import OpenAILLM
+ from config_provider import ConfigProvider
+ from logger import log_info, log_error, log_warning, log_debug
+
+ # Load variables from a local .env file, if present
+ load_dotenv()
+
+ class LLMFactory:
+     @staticmethod
+     def create_provider() -> LLMInterface:
+         """Create LLM provider based on configuration"""
+         cfg = ConfigProvider.get()
+         llm_config = cfg.global_config.llm_provider
+
+         if not llm_config:
+             raise ValueError("No LLM provider configured")
+
+         provider_name = llm_config.name
+         log_info(f"🏭 Creating LLM provider: {provider_name}")
+
+         # Get provider definition
+         provider_def = cfg.global_config.get_provider_config("llm", provider_name)
+         if not provider_def:
+             raise ValueError(f"Unknown LLM provider: {provider_name}")
+
+         # Get API key
+         api_key = LLMFactory._get_api_key(provider_name, llm_config.api_key)
+
+         # Create provider based on name
+         if provider_name in ("spark", "spark_cloud"):
+             return LLMFactory._create_spark_provider(llm_config, api_key, provider_def)
+         elif provider_name in ("gpt-4o", "gpt-4o-mini"):
+             return LLMFactory._create_gpt_provider(llm_config, api_key, provider_def)
+         else:
+             raise ValueError(f"Unsupported LLM provider: {provider_name}")
+
+     @staticmethod
+     def _create_spark_provider(llm_config, api_key, provider_def):
+         """Create Spark LLM provider"""
+         endpoint = llm_config.endpoint
+         if not endpoint:
+             raise ValueError("Spark endpoint not configured")
+
+         # Determine variant based on environment (SPACE_ID is set on HF Spaces)
+         is_cloud = bool(os.environ.get("SPACE_ID"))
+         variant = "hfcloud" if is_cloud else "on-premise"
+
+         return SparkLLM(
+             spark_endpoint=endpoint,
+             spark_token=api_key,
+             provider_variant=variant,
+             settings=llm_config.settings
+         )
+
+     @staticmethod
+     def _create_gpt_provider(llm_config, api_key, provider_def):
+         """Create OpenAI GPT provider"""
+         return OpenAILLM(
+             api_key=api_key,
+             model=llm_config.name,
+             settings=llm_config.settings
+         )
+
+     @staticmethod
+     def _get_api_key(provider_name: str, configured_key: Optional[str]) -> str:
+         """Get API key from config or environment"""
+         # First try configured key
+         if configured_key:
+             # Handle encrypted keys
+             if configured_key.startswith("enc:"):
+                 from encryption_utils import decrypt
+                 return decrypt(configured_key)
+             return configured_key
+
+         # Then try environment variables
+         env_mappings = {
+             "spark": "SPARK_TOKEN",
+             "spark_cloud": "SPARK_TOKEN",  # same token as on-prem Spark
+             "gpt-4o": "OPENAI_API_KEY",
+             "gpt-4o-mini": "OPENAI_API_KEY"
+         }
+
+         env_var = env_mappings.get(provider_name)
+         if env_var:
+             key = os.environ.get(env_var)
+             if key:
+                 log_info(f"📌 Using API key from environment: {env_var}")
+                 return key
+
+         raise ValueError(f"No API key found for provider: {provider_name}")
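A minimal usage sketch for the factory (assumes the host app has already initialized ConfigProvider; `llm_factory` and `logger` are this repo's modules):

    from llm_factory import LLMFactory
    from logger import log_error

    try:
        llm = LLMFactory.create_provider()
        print(llm.get_provider_name(), llm.get_model_info())
    except ValueError as e:
        # Raised for a missing/unknown provider or a missing API key
        log_error("LLM provider creation failed", error=str(e))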
llm/llm_interface.py ADDED
@@ -0,0 +1,35 @@
+ """
+ LLM Provider Interface for Flare
+ """
+ from abc import ABC, abstractmethod
+ from typing import Dict, List, Any
+
+ class LLMInterface(ABC):
+     """Abstract base class for LLM providers"""
+
+     def __init__(self, settings: Dict[str, Any] = None):
+         """Initialize with provider settings"""
+         self.settings = settings or {}
+         self.internal_prompt = self.settings.get("internal_prompt", "")
+         self.parameter_collection_config = self.settings.get("parameter_collection_config", {})
+
+     @abstractmethod
+     async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
+         """Generate response from LLM"""
+         pass
+
+     @abstractmethod
+     async def startup(self, project_config: Dict) -> bool:
+         """Initialize LLM with project config"""
+         pass
+
+     @abstractmethod
+     def get_provider_name(self) -> str:
+         """Get provider name for logging"""
+         pass
+
+     @abstractmethod
+     def get_model_info(self) -> Dict[str, Any]:
+         """Get model information"""
+         pass
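To illustrate the contract, here is a hypothetical in-memory implementation (an `EchoLLM` useful as a test double; it is not part of this commit):

    from typing import Dict, List, Any
    from llm_interface import LLMInterface

    class EchoLLM(LLMInterface):
        """Hypothetical test double: echoes the user input back."""

        async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
            return f"echo: {user_input}"

        async def startup(self, project_config: Dict) -> bool:
            return True  # nothing to initialize

        def get_provider_name(self) -> str:
            return "echo"

        def get_model_info(self) -> Dict[str, Any]:
            return {"provider": "echo", "model": "none"}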
llm/llm_openai.py ADDED
@@ -0,0 +1,104 @@
+ """
+ OpenAI GPT Implementation
+ """
+ import os
+ import openai
+ from typing import Dict, List, Any
+ from llm_interface import LLMInterface
+ from logger import log_info, log_error, log_warning, log_debug, LogTimer
+
+ DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60"))
+ MAX_RESPONSE_LENGTH = 4096  # Max response length (characters)
+
+ class OpenAILLM(LLMInterface):
+     """OpenAI GPT integration with improved error handling"""
+
+     def __init__(self, api_key: str, model: str = "gpt-4", settings: Dict[str, Any] = None):
+         super().__init__(settings)
+         self.api_key = api_key
+         self.model = model
+         self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
+         # Create the async client once and reuse it across requests
+         self.client = openai.AsyncOpenAI(api_key=api_key, timeout=self.timeout)
+         log_info("🔌 OpenAI LLM initialized", model=self.model, timeout=self.timeout)
+
+     async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
+         """Generate response with consistent error handling"""
+
+         # Build messages
+         messages = []
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+
+         # Add context (last 10 messages)
+         for msg in context[-10:]:
+             role = "assistant" if msg.get("role") == "assistant" else "user"
+             messages.append({"role": role, "content": msg.get("content", "")})
+
+         # Add current input
+         messages.append({"role": "user", "content": user_input})
+
+         try:
+             with LogTimer(f"OpenAI {self.model} request"):
+                 response = await self.client.chat.completions.create(
+                     model=self.model,
+                     messages=messages,
+                     max_tokens=self.settings.get("max_tokens", 2048),
+                     temperature=self.settings.get("temperature", 0.7),
+                     stream=False
+                 )
+
+             # Extract content
+             content = response.choices[0].message.content
+
+             # Truncate overly long responses
+             if len(content) > MAX_RESPONSE_LENGTH:
+                 log_warning("Response exceeded max length, truncating",
+                             original_length=len(content),
+                             max_length=MAX_RESPONSE_LENGTH)
+                 content = content[:MAX_RESPONSE_LENGTH] + "..."
+
+             # Log token usage
+             if response.usage:
+                 log_info("Token usage",
+                          prompt_tokens=response.usage.prompt_tokens,
+                          completion_tokens=response.usage.completion_tokens,
+                          total_tokens=response.usage.total_tokens)
+
+             return content
+
+         except openai.RateLimitError as e:
+             log_warning("OpenAI rate limit", error=str(e))
+             raise
+         except openai.APITimeoutError as e:
+             log_error("OpenAI timeout", error=str(e))
+             raise
+         except openai.APIError as e:
+             log_error("OpenAI API error",
+                       status_code=getattr(e, "status_code", None),
+                       error=str(e))
+             raise
+         except Exception as e:
+             log_error("OpenAI unexpected error", error=str(e))
+             raise
+
+     async def startup(self, project_config: Dict) -> bool:
+         """OpenAI doesn't need startup"""
+         log_info("OpenAI startup called (no-op)")
+         return True
+
+     def get_provider_name(self) -> str:
+         return f"openai-{self.model}"
+
+     def get_model_info(self) -> Dict[str, Any]:
+         return {
+             "provider": "openai",
+             "model": self.model,
+             "max_tokens": self.settings.get("max_tokens", 2048),
+             "temperature": self.settings.get("temperature", 0.7)
+         }
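A usage sketch for the OpenAI provider (assumes a valid OpenAI API key and network access; the key shown is a placeholder):

    import asyncio
    from llm_openai import OpenAILLM

    async def main():
        llm = OpenAILLM(api_key="sk-...", model="gpt-4o-mini")
        reply = await llm.generate(
            system_prompt="You are a helpful assistant.",
            user_input="Say hello.",
            context=[]  # no prior conversation
        )
        print(reply)

    asyncio.run(main())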
llm/llm_spark.py ADDED
@@ -0,0 +1,116 @@
+ """
+ Spark LLM Implementation
+ """
+ import os
+ import httpx
+ from typing import Dict, List, Any
+ from llm_interface import LLMInterface
+ from logger import log_info, log_error, log_warning, log_debug, LogTimer
+
+ # Get timeout and response cap from environment
+ DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60"))
+ MAX_RESPONSE_LENGTH = int(os.getenv("LLM_MAX_RESPONSE_LENGTH", "4096"))
+
+ class SparkLLM(LLMInterface):
+     """Spark LLM integration with improved error handling"""
+
+     def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None):
+         super().__init__(settings)
+         self.spark_endpoint = spark_endpoint.rstrip("/")
+         self.spark_token = spark_token
+         self.provider_variant = provider_variant
+         self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
+         log_info("🔌 SparkLLM initialized", endpoint=self.spark_endpoint, timeout=self.timeout)
+
+     async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
+         """Generate response with improved error handling"""
+         headers = {
+             "Authorization": f"Bearer {self.spark_token}",
+             "Content-Type": "application/json"
+         }
+
+         # Build context messages
+         messages = []
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+
+         for msg in context[-10:]:  # Last 10 messages for context
+             messages.append({
+                 "role": msg.get("role", "user"),
+                 "content": msg.get("content", "")
+             })
+
+         messages.append({"role": "user", "content": user_input})
+
+         payload = {
+             "messages": messages,
+             "mode": self.provider_variant,
+             "max_tokens": self.settings.get("max_tokens", 2048),
+             "temperature": self.settings.get("temperature", 0.7),
+             "stream": False  # For now, no streaming
+         }
+
+         try:
+             async with httpx.AsyncClient(timeout=self.timeout) as client:
+                 with LogTimer("Spark LLM request"):
+                     response = await client.post(
+                         f"{self.spark_endpoint}/generate",
+                         json=payload,
+                         headers=headers
+                     )
+
+                 # Check for rate limiting
+                 if response.status_code == 429:
+                     retry_after = response.headers.get("Retry-After", "60")
+                     log_warning("Rate limited by Spark", retry_after=retry_after)
+                     raise httpx.HTTPStatusError(
+                         f"Rate limited. Retry after {retry_after}s",
+                         request=response.request,
+                         response=response
+                     )
+
+                 response.raise_for_status()
+                 result = response.json()
+
+             # Extract response
+             content = result.get("model_answer", "")
+
+             # Truncate overly long responses
+             if len(content) > MAX_RESPONSE_LENGTH:
+                 log_warning("Response exceeded max length, truncating",
+                             original_length=len(content),
+                             max_length=MAX_RESPONSE_LENGTH)
+                 content = content[:MAX_RESPONSE_LENGTH] + "..."
+
+             return content
+
+         except httpx.TimeoutException:
+             log_error("Spark request timed out", timeout=self.timeout)
+             raise
+         except httpx.HTTPStatusError as e:
+             log_error("Spark HTTP error",
+                       status_code=e.response.status_code,
+                       response=e.response.text[:500])
+             raise
+         except Exception as e:
+             log_error("Spark unexpected error", error=str(e))
+             raise
+
+     async def startup(self, project_config: Dict) -> bool:
+         """Notify Spark of a project startup (required by LLMInterface).
+
+         NOTE: the /startup path below is an assumption; the Spark-side
+         contract is not shown in this commit. Adjust to the real API.
+         """
+         headers = {
+             "Authorization": f"Bearer {self.spark_token}",
+             "Content-Type": "application/json"
+         }
+         try:
+             async with httpx.AsyncClient(timeout=self.timeout) as client:
+                 response = await client.post(
+                     f"{self.spark_endpoint}/startup",  # assumed path
+                     json=project_config,
+                     headers=headers
+                 )
+                 response.raise_for_status()
+             return True
+         except Exception as e:
+             log_error("Spark startup failed", error=str(e))
+             return False
+
+     def get_provider_name(self) -> str:
+         return f"spark-{self.provider_variant}"
+
+     def get_model_info(self) -> Dict[str, Any]:
+         return {
+             "provider": "spark",
+             "variant": self.provider_variant,
+             "endpoint": self.spark_endpoint,
+             "max_tokens": self.settings.get("max_tokens", 2048),
+             "temperature": self.settings.get("temperature", 0.7)
+         }
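For reference, the wire contract SparkLLM assumes, as inferred from the code above (not from a Spark spec): a POST to `{endpoint}/generate` with this payload, answered by a JSON body carrying `model_answer`:

    # Request body built by SparkLLM.generate (values shown are the defaults)
    payload = {
        "messages": [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."},
        ],
        "mode": "hfcloud",   # provider_variant chosen by the factory
        "max_tokens": 2048,
        "temperature": 0.7,
        "stream": False,
    }

    # Expected response body
    response_body = {"model_answer": "generated text"}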
llm/llm_startup.py ADDED
@@ -0,0 +1,102 @@
+ """
+ Flare – LLM startup notifier (Refactored)
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ Sends a startup call to the LLM provider when projects come online.
+ """
+
+ from __future__ import annotations
+ import threading
+ import asyncio
+ from typing import Dict, Any
+ from logger import log_info, log_error, log_warning, log_debug
+ from config_provider import ConfigProvider, ProjectConfig, VersionConfig
+ from llm_factory import LLMFactory
+
+ def _select_live_version(p: ProjectConfig) -> VersionConfig | None:
+     """Return the most recently published version, if any."""
+     published = [v for v in p.versions if v.published]
+     return max(published, key=lambda v: v.no) if published else None
+
+ async def notify_startup_async():
+     """Notify LLM provider about project startups (async version)"""
+     cfg = ConfigProvider.get()
+
+     # Check if LLM provider requires repo info
+     llm_provider_def = cfg.global_config.get_provider_config(
+         "llm",
+         cfg.global_config.llm_provider.name
+     )
+
+     if not llm_provider_def or not llm_provider_def.requires_repo_info:
+         log_info(f"ℹ️ LLM provider '{cfg.global_config.llm_provider.name}' does not require startup notification")
+         return
+
+     # Create LLM provider instance
+     try:
+         llm_provider = LLMFactory.create_provider()
+     except Exception as e:
+         log_error("❌ Failed to create LLM provider for startup", error=str(e))
+         return
+
+     # Notify for each enabled project
+     enabled_projects = [p for p in cfg.projects if p.enabled and not getattr(p, 'deleted', False)]
+
+     if not enabled_projects:
+         log_info("ℹ️ No enabled projects found for startup notification")
+         return
+
+     for project in enabled_projects:
+         version = _select_live_version(project)
+         if not version:
+             log_warning(f"⚠️ No published version found for project '{project.name}', skipping startup")
+             continue
+
+         # Build project config (uses version.no rather than version.id)
+         project_config = {
+             "name": project.name,
+             "version_no": version.no,  # version_no instead of version_id
+             "repo_id": version.llm.repo_id,
+             "generation_config": version.llm.generation_config,
+             "use_fine_tune": version.llm.use_fine_tune,
+             "fine_tune_zip": version.llm.fine_tune_zip
+         }
+
+         try:
+             log_info(f"🚀 Notifying LLM provider startup for project '{project.name}'...")
+             success = await llm_provider.startup(project_config)
+
+             if success:
+                 log_info(f"✅ LLM provider acknowledged startup for '{project.name}'")
+             else:
+                 log_warning(f"⚠️ LLM provider startup failed for '{project.name}'")
+
+         except Exception as e:
+             log_error(f"❌ Error during startup notification for '{project.name}'", error=str(e))
+
+ def notify_startup():
+     """Synchronous wrapper for async startup notification"""
+     # Create a fresh event loop for this thread
+     loop = asyncio.new_event_loop()
+     asyncio.set_event_loop(loop)
+
+     try:
+         loop.run_until_complete(notify_startup_async())
+     finally:
+         loop.close()
+
+ def run_in_thread():
+     """Start startup notification in a background thread"""
+     cfg = ConfigProvider.get()
+
+     # Check if provider requires startup
+     llm_provider_def = cfg.global_config.get_provider_config(
+         "llm",
+         cfg.global_config.llm_provider.name
+     )
+
+     if not llm_provider_def or not llm_provider_def.requires_repo_info:
+         log_info(f"🤖 {cfg.global_config.llm_provider.name} - Startup notification not required")
+         return
+
+     log_info("🚀 Starting LLM provider startup notification thread...")
+     threading.Thread(target=notify_startup, daemon=True).start()
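A wiring sketch for the notifier (the host application's bootstrap hook is not shown in this commit, so `on_app_startup` is a hypothetical name; `run_in_thread` is safe to call from synchronous code because it spawns a daemon thread):

    # Sketch: fire the startup notification during app bootstrap.
    # Assumes ConfigProvider.get() succeeds at this point.
    from llm_startup import run_in_thread

    def on_app_startup():  # hypothetical host-app hook
        run_in_thread()    # non-blocking; work happens on a daemon thread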