Spaces:
Sleeping
Sleeping
Upload multit2i.py
Browse files- multit2i.py +15 -4
multit2i.py
CHANGED
@@ -143,20 +143,23 @@ def save_gallery(image_path: str | None, images: list[tuple] | None):
|
|
143 |
|
144 |
# https://github.com/gradio-app/gradio/blob/main/gradio/external.py
|
145 |
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
|
146 |
-
|
|
|
147 |
import httpx
|
148 |
import huggingface_hub
|
149 |
-
from gradio.exceptions import ModelNotFoundError
|
150 |
model_url = f"https://huggingface.co/{model_name}"
|
151 |
api_url = f"https://api-inference.huggingface.co/models/{model_name}"
|
152 |
print(f"Fetching model from: {model_url}")
|
153 |
|
154 |
-
headers = {"Authorization": f"Bearer {hf_token}"}
|
155 |
response = httpx.request("GET", api_url, headers=headers)
|
156 |
if response.status_code != 200:
|
157 |
raise ModelNotFoundError(
|
158 |
f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
|
159 |
)
|
|
|
|
|
160 |
headers["X-Wait-For-Model"] = "true"
|
161 |
client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
|
162 |
token=hf_token, timeout=server_timeout)
|
@@ -165,7 +168,14 @@ def load_from_model(model_name: str, hf_token: str = None):
|
|
165 |
fn = client.text_to_image
|
166 |
|
167 |
def query_huggingface_inference_endpoints(*data, **kwargs):
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
interface_info = {
|
171 |
"fn": query_huggingface_inference_endpoints,
|
@@ -370,6 +380,7 @@ def infer_body(client: InferenceClient | gr.Interface | object, prompt: str, neg
|
|
370 |
elif isinstance(client, gr.Interface):
|
371 |
image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
|
372 |
else: return None
|
|
|
373 |
image.save(png_path)
|
374 |
return str(Path(png_path).resolve())
|
375 |
except Exception as e:
|
|
|
143 |
|
144 |
# https://github.com/gradio-app/gradio/blob/main/gradio/external.py
|
145 |
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
|
146 |
+
from typing import Literal
|
147 |
+
def load_from_model(model_name: str, hf_token: str | Literal[False] | None = None):
|
148 |
import httpx
|
149 |
import huggingface_hub
|
150 |
+
from gradio.exceptions import ModelNotFoundError, TooManyRequestsError
|
151 |
model_url = f"https://huggingface.co/{model_name}"
|
152 |
api_url = f"https://api-inference.huggingface.co/models/{model_name}"
|
153 |
print(f"Fetching model from: {model_url}")
|
154 |
|
155 |
+
headers = ({} if hf_token in [False, None] else {"Authorization": f"Bearer {hf_token}"})
|
156 |
response = httpx.request("GET", api_url, headers=headers)
|
157 |
if response.status_code != 200:
|
158 |
raise ModelNotFoundError(
|
159 |
f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
|
160 |
)
|
161 |
+
p = response.json().get("pipeline_tag")
|
162 |
+
if p != "text-to-image": raise ModelNotFoundError(f"This model isn't for text-to-image or unsupported: {model_name}.")
|
163 |
headers["X-Wait-For-Model"] = "true"
|
164 |
client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
|
165 |
token=hf_token, timeout=server_timeout)
|
|
|
168 |
fn = client.text_to_image
|
169 |
|
170 |
def query_huggingface_inference_endpoints(*data, **kwargs):
|
171 |
+
try:
|
172 |
+
data = fn(*data, **kwargs) # type: ignore
|
173 |
+
except huggingface_hub.utils.HfHubHTTPError as e:
|
174 |
+
if "429" in str(e):
|
175 |
+
raise TooManyRequestsError() from e
|
176 |
+
except Exception as e:
|
177 |
+
raise Exception(e)
|
178 |
+
return data
|
179 |
|
180 |
interface_info = {
|
181 |
"fn": query_huggingface_inference_endpoints,
|
|
|
380 |
elif isinstance(client, gr.Interface):
|
381 |
image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
|
382 |
else: return None
|
383 |
+
if isinstance(image, tuple): return None
|
384 |
image.save(png_path)
|
385 |
return str(Path(png_path).resolve())
|
386 |
except Exception as e:
|