mgbam commited on
Commit
6756974
·
verified ·
1 Parent(s): 8f9bbf1

Update services.py

Browse files
Files changed (1) hide show
  1. services.py +104 -47
services.py CHANGED
@@ -3,9 +3,12 @@
3
  """
4
  Manages interactions with external services like LLM providers and web search APIs.
5
  This module has been refactored to support multiple LLM providers:
6
- - Hugging Face (for standard and multimodal models)
7
- - Groq (for high-speed inference)
8
  - Fireworks AI
 
 
 
9
  """
10
  import os
11
  import logging
@@ -18,16 +21,23 @@ from huggingface_hub import InferenceClient
18
  from tavily import TavilyClient
19
  from groq import Groq
20
  import fireworks.client as Fireworks
 
 
 
21
 
22
  # --- Setup Logging & Environment ---
23
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
  load_dotenv()
25
 
26
- # --- API Keys ---
27
  HF_TOKEN = os.getenv("HF_TOKEN")
28
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
29
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
30
  FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
 
 
 
 
31
 
32
  # --- Type Definitions ---
33
  Messages = List[Dict[str, Any]]
@@ -36,78 +46,125 @@ class LLMService:
36
  """A multi-provider wrapper for LLM Inference APIs."""
37
 
38
  def __init__(self):
39
- # Initialize clients if their API keys are available
40
  self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
41
  self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
42
- self.fireworks_client = Fireworks if FIREWORKS_API_KEY else None
43
- if self.fireworks_client:
44
- self.fireworks_client.api_key = FIREWORKS_API_KEY
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def generate_code_stream(
47
- self, model_id: str, messages: Messages, max_tokens: int = 8000
48
  ) -> Generator[str, None, None]:
49
  """
50
  Streams code generation, dispatching to the correct provider based on model_id.
51
- The model_id format is 'provider/model-name' or a full HF model_id.
52
  """
53
- provider = "huggingface" # Default provider
54
- model_name = model_id
55
-
56
- if '/' in model_id:
57
- parts = model_id.split('/', 1)
58
- if parts[0] in ['groq', 'fireworks', 'huggingface']:
59
- provider = parts[0]
60
- model_name = parts[1]
61
-
62
  logging.info(f"Dispatching to provider: {provider} for model: {model_name}")
63
 
64
  try:
65
- # --- Groq Provider ---
66
- if provider == 'groq':
67
- if not self.groq_client:
68
- raise ValueError("Groq API key is not configured.")
69
- stream = self.groq_client.chat.completions.create(
70
- model=model_name, messages=messages, stream=True, max_tokens=max_tokens
71
- )
 
 
 
 
 
 
 
 
 
 
 
72
  for chunk in stream:
73
- if chunk.choices[0].delta.content:
74
  yield chunk.choices[0].delta.content
75
 
76
- # --- Fireworks AI Provider ---
77
- elif provider == 'fireworks':
78
- if not self.fireworks_client:
79
- raise ValueError("Fireworks AI API key is not configured.")
80
- stream = self.fireworks_client.ChatCompletion.create(
81
- model=model_name, messages=messages, stream=True, max_tokens=max_tokens
82
- )
 
 
 
 
 
 
 
 
 
83
  for chunk in stream:
84
  if chunk.choices[0].delta.content:
85
  yield chunk.choices[0].delta.content
86
 
87
- # --- Hugging Face Provider (Default) ---
88
  else:
89
- if not self.hf_client:
90
- raise ValueError("Hugging Face API token is not configured.")
91
- # For HF, the model_name is the full original model_id
92
- stream = self.hf_client.chat_completion(
93
- model=model_name, messages=messages, stream=True, max_tokens=max_tokens
94
- )
95
- for chunk in stream:
96
- yield chunk.choices[0].delta.content
97
 
98
  except Exception as e:
99
  logging.error(f"LLM API Error with provider {provider}: {e}")
100
  yield f"Error from {provider.capitalize()}: {str(e)}"
101
 
102
-
103
  class SearchService:
104
- # (This class remains unchanged)
105
  def __init__(self, api_key: str = TAVILY_API_KEY):
106
- # ... existing code ...
 
 
 
 
 
 
 
 
 
107
  def is_available(self) -> bool:
108
- # ... existing code ...
 
 
109
  def search(self, query: str, max_results: int = 5) -> str:
110
- # ... existing code ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  # --- Singleton Instances ---
113
  llm_service = LLMService()
 
3
  """
4
  Manages interactions with external services like LLM providers and web search APIs.
5
  This module has been refactored to support multiple LLM providers:
6
+ - Hugging Face
7
+ - Groq
8
  - Fireworks AI
9
+ - OpenAI
10
+ - Google Gemini
11
+ - DeepSeek (Direct API)
12
  """
13
  import os
14
  import logging
 
21
  from tavily import TavilyClient
22
  from groq import Groq
23
  import fireworks.client as Fireworks
24
+ import openai
25
+ import google.generativeai as genai
26
+ from deepseek import OpenaiClient as DeepSeekClient
27
 
28
  # --- Setup Logging & Environment ---
29
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
30
  load_dotenv()
31
 
32
+ # --- API Keys from .env ---
33
  HF_TOKEN = os.getenv("HF_TOKEN")
34
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
35
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
36
  FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
37
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
38
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
39
+ DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
40
+
41
 
42
  # --- Type Definitions ---
43
  Messages = List[Dict[str, Any]]
 
46
  """A multi-provider wrapper for LLM Inference APIs."""
47
 
48
  def __init__(self):
49
+ # Initialize clients only if their API keys are available
50
  self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
51
  self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
52
+ self.openai_client = openai.OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
53
+ self.deepseek_client = DeepSeekClient(api_key=DEEPSEEK_API_KEY) if DEEPSEEK_API_KEY else None
54
+
55
+ if FIREWORKS_API_KEY:
56
+ Fireworks.api_key = FIREWORKS_API_KEY
57
+ self.fireworks_client = Fireworks
58
+ else:
59
+ self.fireworks_client = None
60
+
61
+ if GEMINI_API_KEY:
62
+ genai.configure(api_key=GEMINI_API_KEY)
63
+ self.gemini_model = genai.GenerativeModel('gemini-1.5-pro-latest')
64
+ else:
65
+ self.gemini_model = None
66
+
67
+ def _prepare_messages_for_gemini(self, messages: Messages) -> List[Dict[str, Any]]:
68
+ """Gemini requires a slightly different message format."""
69
+ gemini_messages = []
70
+ for msg in messages:
71
+ # Gemini uses 'model' for assistant role
72
+ role = 'model' if msg['role'] == 'assistant' else 'user'
73
+ gemini_messages.append({'role': role, 'parts': [msg['content']]})
74
+ return gemini_messages
75
 
76
  def generate_code_stream(
77
+ self, model_id: str, messages: Messages, max_tokens: int = 8192
78
  ) -> Generator[str, None, None]:
79
  """
80
  Streams code generation, dispatching to the correct provider based on model_id.
 
81
  """
82
+ provider, model_name = model_id.split('/', 1)
 
 
 
 
 
 
 
 
83
  logging.info(f"Dispatching to provider: {provider} for model: {model_name}")
84
 
85
  try:
86
+ # --- OpenAI, Groq, DeepSeek, Fireworks (OpenAI-compatible) ---
87
+ if provider in ['openai', 'groq', 'deepseek', 'fireworks']:
88
+ client_map = {
89
+ 'openai': self.openai_client,
90
+ 'groq': self.groq_client,
91
+ 'deepseek': self.deepseek_client,
92
+ 'fireworks': self.fireworks_client.ChatCompletion if self.fireworks_client else None,
93
+ }
94
+ client = client_map.get(provider)
95
+ if not client:
96
+ raise ValueError(f"{provider.capitalize()} API key not configured.")
97
+
98
+ # Fireworks has a slightly different call signature
99
+ if provider == 'fireworks':
100
+ stream = client.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
101
+ else:
102
+ stream = client.chat.completions.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
103
+
104
  for chunk in stream:
105
+ if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
106
  yield chunk.choices[0].delta.content
107
 
108
+ # --- Google Gemini ---
109
+ elif provider == 'gemini':
110
+ if not self.gemini_model:
111
+ raise ValueError("Gemini API key not configured.")
112
+ gemini_messages = self._prepare_messages_for_gemini(messages)
113
+ stream = self.gemini_model.generate_content(gemini_messages, stream=True)
114
+ for chunk in stream:
115
+ yield chunk.text
116
+
117
+ # --- Hugging Face ---
118
+ elif provider == 'huggingface':
119
+ if not self.hf_client:
120
+ raise ValueError("Hugging Face API token not configured.")
121
+ # For HF, model_name is the rest of the ID, e.g., baidu/ERNIE...
122
+ hf_model_id = model_id.split('/', 1)[1]
123
+ stream = self.hf_client.chat_completion(model=hf_model_id, messages=messages, stream=True, max_tokens=max_tokens)
124
  for chunk in stream:
125
  if chunk.choices[0].delta.content:
126
  yield chunk.choices[0].delta.content
127
 
 
128
  else:
129
+ raise ValueError(f"Unknown provider: {provider}")
 
 
 
 
 
 
 
130
 
131
  except Exception as e:
132
  logging.error(f"LLM API Error with provider {provider}: {e}")
133
  yield f"Error from {provider.capitalize()}: {str(e)}"
134
 
 
135
class SearchService:
    """A wrapper for the Tavily Search API."""

    def __init__(self, api_key: str = TAVILY_API_KEY):
        """Build the Tavily client; search is disabled when no key is configured."""
        self.client = None
        if not api_key:
            logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")
            return
        try:
            self.client = TavilyClient(api_key=api_key)
        except Exception as e:
            logging.error(f"Failed to initialize Tavily client: {e}")
            self.client = None

    def is_available(self) -> bool:
        """Checks if the search service is configured and available."""
        return self.client is not None

    def search(self, query: str, max_results: int = 5) -> str:
        """Performs a web search and returns a formatted string of results."""
        if not self.is_available():
            return "Web search is not available."
        try:
            # Clamp the requested result count to Tavily's sensible 1..10 range.
            capped = min(max(1, max_results), 10)
            response = self.client.search(
                query, search_depth="advanced", max_results=capped
            )
            formatted = []
            for res in response.get('results', []):
                formatted.append(
                    f"Title: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nContent: {res.get('content', 'N/A')}"
                )
            if not formatted:
                return "No search results found."
            return "Web Search Results:\n\n" + "\n---\n".join(formatted)
        except Exception as e:
            logging.error(f"Tavily search error: {e}")
            return f"Search error: {str(e)}"
168
 
169
  # --- Singleton Instances ---
170
  llm_service = LLMService()