Geek7 commited on
Commit
5bc3f97
·
verified ·
1 Parent(s): 67f23ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -1
app.py CHANGED
@@ -5,6 +5,8 @@ from externalmod import gr_Interface_load
5
  import asyncio
6
  import os
7
  from threading import RLock
 
 
8
 
9
  lock = RLock()
10
  HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -29,7 +31,7 @@ MAX_SEED = 3999999999
29
  default_models = models[:num_models]
30
  inference_timeout = 600
31
 
32
- async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
33
  kwargs = {"seed": seed}
34
  task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
35
  await asyncio.sleep(0)
@@ -48,6 +50,31 @@ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
48
  return png_path
49
  return None
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Expose Gradio API
52
  def generate_api(model_str, prompt, seed=1):
53
  result = asyncio.run(infer(model_str, prompt, seed))
 
5
  import asyncio
6
  import os
7
  from threading import RLock
8
+ from flask import send_file
9
+
10
 
11
  lock = RLock()
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
31
  default_models = models[:num_models]
32
  inference_timeout = 600
33
 
34
+ #async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
35
  kwargs = {"seed": seed}
36
  task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
37
  await asyncio.sleep(0)
 
50
  return png_path
51
  return None
52
 
53
+
54
+
55
+ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
56
+ kwargs = {"seed": seed}
57
+ task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
58
+ await asyncio.sleep(0)
59
+
60
+ try:
61
+ result = await asyncio.wait_for(task, timeout=timeout)
62
+ except (Exception, asyncio.TimeoutError) as e:
63
+ print(e)
64
+ print(f"Task timed out: {model_str}")
65
+ if not task.done():
66
+ task.cancel()
67
+ return None # Return None if the task fails or times out
68
+
69
+ if task.done() and result is not None:
70
+ with lock:
71
+ png_path = "image.png"
72
+ result.save(png_path) # Save the image to a file
73
+ return send_file(png_path, mimetype='image/png') # Directly use send_file here
74
+
75
+ return None # Return None if no result is generated
76
+
77
+
78
  # Expose Gradio API
79
  def generate_api(model_str, prompt, seed=1):
80
  result = asyncio.run(infer(model_str, prompt, seed))