mgbam commited on
Commit
0466efc
·
verified ·
1 Parent(s): 46c562f

Update api_clients.py

Browse files
Files changed (1) hide show
  1. api_clients.py +9 -20
api_clients.py CHANGED
@@ -21,26 +21,15 @@ from web_extraction import extract_website_content, enhance_query_with_search
21
 
22
  # HF Inference Client
23
  HF_TOKEN = os.getenv('HF_TOKEN')
24
- GROQ_API_KEY = os.getenv('GROQ_API_KEY')
25
- FIREWORKS_API_KEY = os.getenv('FIREWORKS_API_KEY')
26
 
27
  def get_inference_client(model_id):
28
- """Return an InferenceClient configured for Hugging Face, Groq, or Fireworks AI."""
29
- if model_id == "moonshotai/Kimi-K2-Instruct":
30
- return InferenceClient(
31
- base_url="https://api.groq.com/openai/v1",
32
- api_key=GROQ_API_KEY
33
- )
34
- elif model_id.startswith("fireworks/"):
35
- return InferenceClient(
36
- base_url="https://api.fireworks.ai/inference/v1",
37
- api_key=FIREWORKS_API_KEY
38
- )
39
- else:
40
- return InferenceClient(
41
- model=model_id,
42
- api_key=HF_TOKEN
43
- )
44
 
45
  # Tavily Search Client
46
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
@@ -52,7 +41,7 @@ if TAVILY_API_KEY:
52
  print(f"Failed to initialize Tavily client: {e}")
53
  tavily_client = None
54
 
55
- def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[List[Tuple[str, str]]], _current_model: Dict, enable_search: bool = False, language: str = "html"):
56
  if query is None:
57
  query = ''
58
  if _history is None:
@@ -117,7 +106,7 @@ This will help me create a better design for you."""
117
  messages.append({'role': 'user', 'content': enhanced_query})
118
  try:
119
  completion = client.chat.completions.create(
120
- model=_current_model["id"],
121
  messages=messages,
122
  stream=True,
123
  max_tokens=5000
 
21
 
22
  # HF Inference Client
23
  HF_TOKEN = os.getenv('HF_TOKEN')
 
 
24
 
25
  def get_inference_client(model_id):
26
+ """Return an InferenceClient with provider based on model_id."""
27
+ provider = "groq" if model_id == "moonshotai/Kimi-K2-Instruct" else "auto"
28
+ return InferenceClient(
29
+ provider=provider,
30
+ api_key=HF_TOKEN,
31
+ bill_to="huggingface"
32
+ )
 
 
 
 
 
 
 
 
 
33
 
34
  # Tavily Search Client
35
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
 
41
  print(f"Failed to initialize Tavily client: {e}")
42
  tavily_client = None
43
 
44
+ async def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[List[Tuple[str, str]]], _current_model: Dict, enable_search: bool = False, language: str = "html"):
45
  if query is None:
46
  query = ''
47
  if _history is None:
 
106
  messages.append({'role': 'user', 'content': enhanced_query})
107
  try:
108
  completion = client.chat.completions.create(
109
+ model=_current_model["id"], # Corrected this line
110
  messages=messages,
111
  stream=True,
112
  max_tokens=5000