jarguello76 commited on
Commit
52a1284
·
verified ·
1 Parent(s): 7de7748

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -20,14 +20,27 @@ os.environ["HUGGINGFACE_API_KEY"] = hf_token
20
 
21
 
22
  # Define the HuggingFaceInferenceWrapper class correctly
 
 
23
  class HuggingFaceInferenceWrapper:
24
  def __init__(self, inference_api):
25
  self.inference_api = inference_api
26
 
27
  def generate(self, prompt: str, **kwargs) -> str:
28
- # Call inference API - returns string directly
29
- response = self.inference_api(inputs=prompt)
30
- return response.strip()
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  def run_and_submit_all(profile: gr.OAuthProfile | None):
 
20
 
21
 
22
  # Define the HuggingFaceInferenceWrapper class correctly
23
+ import json
24
+
25
  class HuggingFaceInferenceWrapper:
26
  def __init__(self, inference_api):
27
  self.inference_api = inference_api
28
 
29
  def generate(self, prompt: str, **kwargs) -> str:
30
+ # Request raw response and handle JSON decoding
31
+ response = self.inference_api(inputs=prompt, raw_response=True)
32
+ result_json = response.json()
33
+
34
+ # Handle both plain text and chat-style model formats
35
+ if isinstance(result_json, dict):
36
+ if "generated_text" in result_json:
37
+ return result_json["generated_text"].strip()
38
+ elif "outputs" in result_json and isinstance(result_json["outputs"], str):
39
+ return result_json["outputs"].strip()
40
+
41
+ # If unknown format, fallback
42
+ return str(result_json)
43
+
44
 
45
 
46
  def run_and_submit_all(profile: gr.OAuthProfile | None):