mgbam commited on
Commit
20ee71d
·
verified ·
1 Parent(s): b1d6425

Update api_clients.py

Browse files
Files changed (1) hide show
  1. api_clients.py +14 -19
api_clients.py CHANGED
@@ -21,26 +21,15 @@ from web_extraction import extract_website_content, enhance_query_with_search
21
 
22
  # HF Inference Client
23
  HF_TOKEN = os.getenv('HF_TOKEN')
24
- GROQ_API_KEY = os.getenv('GROQ_API_KEY')
25
- FIREWORKS_API_KEY = os.getenv('FIREWORKS_API_KEY')
26
 
27
  def get_inference_client(model_id):
28
- """Return an InferenceClient configured for Hugging Face, Groq, or Fireworks AI."""
29
- if model_id == "moonshotai/Kimi-K2-Instruct":
30
- return InferenceClient(
31
- base_url="https://api.groq.com/openai/v1",
32
- api_key=GROQ_API_KEY
33
- )
34
- elif model_id.startswith("fireworks/"):
35
- return InferenceClient(
36
- base_url="https://api.fireworks.ai/inference/v1",
37
- api_key=FIREWORKS_API_KEY
38
- )
39
- else:
40
- return InferenceClient(
41
- model=model_id,
42
- api_key=HF_TOKEN
43
- )
44
 
45
  # Tavily Search Client
46
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
@@ -52,7 +41,7 @@ if TAVILY_API_KEY:
52
  print(f"Failed to initialize Tavily client: {e}")
53
  tavily_client = None
54
 
55
- async def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[List[Tuple[str, str]]], _current_model: Dict, enable_search: bool = False, language: str = "html"):
56
  if query is None:
57
  query = ''
58
  if _history is None:
@@ -65,6 +54,7 @@ async def generation_code(query: Optional[str], image: Optional[gr.Image], file:
65
  last_assistant_msg = _history[-1][1] if len(_history) > 0 else ""
66
  if '<!DOCTYPE html>' in last_assistant_msg or '<html' in last_assistant_msg:
67
  has_existing_html = True
 
68
 
69
  # Choose system prompt based on context
70
  if has_existing_html:
@@ -81,12 +71,14 @@ async def generation_code(query: Optional[str], image: Optional[gr.Image], file:
81
 
82
  # Extract file text and append to query if file is present
83
  file_text = ""
 
84
  if file:
85
  file_text = extract_text_from_file(file)
86
  if file_text:
87
  file_text = file_text[:5000] # Limit to 5000 chars for prompt size
88
  query = f"{query}\n\n[Reference file content below]\n{file_text}"
89
 
 
90
  # Extract website content and append to query if website URL is present
91
  website_text = ""
92
  if website_url and website_url.strip():
@@ -105,6 +97,7 @@ Since I couldn't extract the website content, please provide additional details
105
  This will help me create a better design for you."""
106
  query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"
107
 
 
108
  # Enhance query with search if enabled
109
  enhanced_query = enhance_query_with_search(query, enable_search)
110
 
@@ -115,6 +108,7 @@ This will help me create a better design for you."""
115
  messages.append(create_multimodal_message(enhanced_query, image))
116
  else:
117
  messages.append({'role': 'user', 'content': enhanced_query})
 
118
  try:
119
  completion = client.chat.completions.create(
120
  model=_current_model["id"], # Corrected this line
@@ -122,6 +116,7 @@ This will help me create a better design for you."""
122
  stream=True,
123
  max_tokens=5000
124
  )
 
125
  content = ""
126
  for chunk in completion:
127
  if chunk.choices[0].delta.content:
 
21
 
22
  # HF Inference Client
23
  HF_TOKEN = os.getenv('HF_TOKEN')
 
 
24
 
25
  def get_inference_client(model_id):
26
+ """Return an InferenceClient with provider based on model_id."""
27
+ provider = "groq" if model_id == "moonshotai/Kimi-K2-Instruct" else "auto"
28
+ return InferenceClient(
29
+ provider=provider,
30
+ api_key=HF_TOKEN,
31
+ bill_to="huggingface"
32
+ )
 
 
 
 
 
 
 
 
 
33
 
34
  # Tavily Search Client
35
  TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
 
41
  print(f"Failed to initialize Tavily client: {e}")
42
  tavily_client = None
43
 
44
+ async def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[List[Tuple[str, str]]], _current_model: Dict, enable_search: bool = False, language: str = "html", progress=gr.Progress(track_tqdm=True)):
45
  if query is None:
46
  query = ''
47
  if _history is None:
 
54
  last_assistant_msg = _history[-1][1] if len(_history) > 0 else ""
55
  if '<!DOCTYPE html>' in last_assistant_msg or '<html' in last_assistant_msg:
56
  has_existing_html = True
57
+ progress(0, desc="Initializing...")
58
 
59
  # Choose system prompt based on context
60
  if has_existing_html:
 
71
 
72
  # Extract file text and append to query if file is present
73
  file_text = ""
74
+ progress(0.1, desc="Processing file...")
75
  if file:
76
  file_text = extract_text_from_file(file)
77
  if file_text:
78
  file_text = file_text[:5000] # Limit to 5000 chars for prompt size
79
  query = f"{query}\n\n[Reference file content below]\n{file_text}"
80
 
81
+ progress(0.2, desc="Extracting website content...")
82
  # Extract website content and append to query if website URL is present
83
  website_text = ""
84
  if website_url and website_url.strip():
 
97
  This will help me create a better design for you."""
98
  query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"
99
 
100
+ progress(0.4, desc="Performing web search...")
101
  # Enhance query with search if enabled
102
  enhanced_query = enhance_query_with_search(query, enable_search)
103
 
 
108
  messages.append(create_multimodal_message(enhanced_query, image))
109
  else:
110
  messages.append({'role': 'user', 'content': enhanced_query})
111
+ progress(0.5, desc="Generating code with AI model...")
112
  try:
113
  completion = client.chat.completions.create(
114
  model=_current_model["id"], # Corrected this line
 
116
  stream=True,
117
  max_tokens=5000
118
  )
119
+ progress(0.6, desc="Streaming response...")
120
  content = ""
121
  for chunk in completion:
122
  if chunk.choices[0].delta.content: