Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,8 @@ from externalmod import gr_Interface_load
|
|
5 |
import asyncio
|
6 |
import os
|
7 |
from threading import RLock
|
|
|
|
|
8 |
|
9 |
lock = RLock()
|
10 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
@@ -29,7 +31,7 @@ MAX_SEED = 3999999999
|
|
29 |
default_models = models[:num_models]
|
30 |
inference_timeout = 600
|
31 |
|
32 |
-
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
|
33 |
kwargs = {"seed": seed}
|
34 |
task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
|
35 |
await asyncio.sleep(0)
|
@@ -48,6 +50,31 @@ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
|
|
48 |
return png_path
|
49 |
return None
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
# Expose Gradio API
|
52 |
def generate_api(model_str, prompt, seed=1):
|
53 |
result = asyncio.run(infer(model_str, prompt, seed))
|
|
|
5 |
import asyncio
|
6 |
import os
|
7 |
from threading import RLock
|
8 |
+
from flask import send_file
|
9 |
+
|
10 |
|
11 |
lock = RLock()
|
12 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
31 |
default_models = models[:num_models]
|
32 |
inference_timeout = 600
|
33 |
|
34 |
+
#async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
|
35 |
kwargs = {"seed": seed}
|
36 |
task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
|
37 |
await asyncio.sleep(0)
|
|
|
50 |
return png_path
|
51 |
return None
|
52 |
|
53 |
+
|
54 |
+
|
55 |
+
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
|
56 |
+
kwargs = {"seed": seed}
|
57 |
+
task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
|
58 |
+
await asyncio.sleep(0)
|
59 |
+
|
60 |
+
try:
|
61 |
+
result = await asyncio.wait_for(task, timeout=timeout)
|
62 |
+
except (Exception, asyncio.TimeoutError) as e:
|
63 |
+
print(e)
|
64 |
+
print(f"Task timed out: {model_str}")
|
65 |
+
if not task.done():
|
66 |
+
task.cancel()
|
67 |
+
return None # Return None if the task fails or times out
|
68 |
+
|
69 |
+
if task.done() and result is not None:
|
70 |
+
with lock:
|
71 |
+
png_path = "image.png"
|
72 |
+
result.save(png_path) # Save the image to a file
|
73 |
+
return send_file(png_path, mimetype='image/png') # Directly use send_file here
|
74 |
+
|
75 |
+
return None # Return None if no result is generated
|
76 |
+
|
77 |
+
|
78 |
# Expose Gradio API
|
79 |
def generate_api(model_str, prompt, seed=1):
|
80 |
result = asyncio.run(infer(model_str, prompt, seed))
|