Geek7 committed on
Commit
e089372
·
verified ·
1 Parent(s): fbed9e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -41
app.py CHANGED
@@ -1,23 +1,18 @@
1
- from flask import Flask, request, jsonify, send_file
2
- from flask_cors import CORS
 
 
3
  import asyncio
4
- from threading import RLock
5
- from gradio_client import Client
6
- from all_models import models # Import your models
7
- from externalmod import gr_Interface_load # Import the model loading function
8
  import os
9
-
10
- app = Flask(__name__)
11
- CORS(app) # Enable CORS for all routes
12
 
13
  lock = RLock()
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
15
- client = Client("Geek7/mdztxi2")
16
 
17
- # Load models using gr_Interface_load
18
  def load_fn(models):
19
  global models_load
20
  models_load = {}
 
21
  for model in models:
22
  if model not in models_load.keys():
23
  try:
@@ -27,42 +22,39 @@ def load_fn(models):
27
  m = gr.Interface(lambda: None, ['text'], ['image'])
28
  models_load.update({model: m})
29
 
30
- # Load the models
31
  load_fn(models)
32
 
33
  num_models = 6
34
  MAX_SEED = 3999999999
35
- inference_timeout = 600 # 10 minutes
 
36
 
37
- # Asynchronous inference function
38
- async def async_infer(model_str, prompt, seed=1, timeout=inference_timeout):
 
 
39
  try:
40
- result = await asyncio.to_thread(client.predict, model_str=model_str, prompt=prompt, seed=seed, api_name="/predict")
 
 
 
 
 
 
 
41
  with lock:
42
  png_path = "image.png"
43
- result.save(png_path) # Save the image to a file
44
  return png_path
45
- except Exception as e:
46
- print(f"Error during inference: {e}")
47
- return None
48
-
49
- # Define the endpoint for asynchronous inference
50
- @app.route('/async_infer', methods=['POST'])
51
- def async_infer_endpoint():
52
- data = request.get_json()
53
- model_str = data['model_str']
54
- prompt = data['prompt']
55
- seed = data.get('seed', 1) # Default seed value is 1
56
-
57
- # Make a prediction request
58
- try:
59
- result_path = asyncio.run(async_infer(model_str, prompt, seed))
60
- if result_path:
61
- return send_file(result_path, mimetype='image/png') # Send the generated image file
62
- else:
63
- return jsonify({"error": "Image generation failed."}), 500
64
- except Exception as e:
65
- return jsonify({"error": str(e)}), 500
66
-
67
- if __name__ == '__main__':
68
- app.run(debug=True)
 
1
+ import gradio as gr
2
+ from random import randint
3
+ from all_models import models
4
+ from externalmod import gr_Interface_load
5
  import asyncio
 
 
 
 
6
  import os
7
+ from threading import RLock
 
 
8
 
9
  lock = RLock()
10
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
11
 
 
12
  def load_fn(models):
13
  global models_load
14
  models_load = {}
15
+
16
  for model in models:
17
  if model not in models_load.keys():
18
  try:
 
22
  m = gr.Interface(lambda: None, ['text'], ['image'])
23
  models_load.update({model: m})
24
 
 
25
  load_fn(models)
26
 
27
  num_models = 6
28
  MAX_SEED = 3999999999
29
+ default_models = models[:num_models]
30
+ inference_timeout = 600
31
 
32
+ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
33
+ kwargs = {"seed": seed}
34
+ task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
35
+ await asyncio.sleep(0)
36
  try:
37
+ result = await asyncio.wait_for(task, timeout=timeout)
38
+ except (Exception, asyncio.TimeoutError) as e:
39
+ print(e)
40
+ print(f"Task timed out: {model_str}")
41
+ if not task.done():
42
+ task.cancel()
43
+ result = None
44
+ if task.done() and result is not None:
45
  with lock:
46
  png_path = "image.png"
47
+ result.save(png_path)
48
  return png_path
49
+ return None
50
+
51
+ # Expose Gradio API
52
+ def generate_api(model_str, prompt, seed=1):
53
+ result = asyncio.run(infer(model_str, prompt, seed))
54
+ if result:
55
+ return result # Path to generated image
56
+ return None
57
+
58
+ # Launch Gradio API without frontend
59
+ iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
60
+ iface.launch(show_api=True, share=True)