Geek7 commited on
Commit
fbed9e3
·
verified ·
1 Parent(s): 7151510

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -44
app.py CHANGED
@@ -1,78 +1,68 @@
1
  from flask import Flask, request, jsonify, send_file
2
- from flask_cors import CORS # For enabling CORS
3
- import gradio as gr
4
  import asyncio
5
- import os
6
  from threading import RLock
7
  from gradio_client import Client
8
- from all_models import models # Your model import
9
- from externalmod import gr_Interface_load # Your custom model loader
 
10
 
11
  app = Flask(__name__)
12
  CORS(app) # Enable CORS for all routes
13
 
14
- # Gradio Client Initialization
15
- client = Client("Geek7/mdztxi2")
16
-
17
  lock = RLock()
18
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
19
 
20
- # Model Loading Function
21
  def load_fn(models):
22
  global models_load
23
  models_load = {}
24
-
25
  for model in models:
26
  if model not in models_load.keys():
27
  try:
28
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
29
  except Exception as error:
30
  print(error)
31
- m = gr.Interface(lambda: None, ['text'], ['image']) # Fallback
32
  models_load.update({model: m})
33
 
 
34
  load_fn(models)
35
 
36
- # Async inference function to call Gradio model prediction
37
- async def infer(model_str, prompt, seed=1, timeout=600):
38
- kwargs = {"seed": seed}
39
- task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
40
- await asyncio.sleep(0)
 
41
  try:
42
- result = await asyncio.wait_for(task, timeout=timeout)
43
- except (Exception, asyncio.TimeoutError) as e:
44
- print(e)
45
- print(f"Task timed out: {model_str}")
46
- if not task.done():
47
- task.cancel()
48
- result = None
49
- if task.done() and result is not None:
50
  with lock:
51
  png_path = "image.png"
52
- result.save(png_path)
53
  return png_path
54
- return None
 
 
55
 
56
- # API endpoint for generating an image and sending it as a file
57
- @app.route('/generate-image', methods=['POST'])
58
- def generate_image():
59
  data = request.get_json()
60
- model_str = data.get('model_str')
61
- prompt = data.get('prompt')
62
- seed = data.get('seed', 1)
63
-
64
- # Validate input
65
- if not model_str or not prompt:
66
- return jsonify({"error": "Model string and prompt are required."}), 400
67
 
68
- # Generate image using the async inference function
69
- result_path = asyncio.run(infer(model_str, prompt, seed))
70
-
71
- if result_path:
72
- # Return the image file using send_file
73
- return send_file(result_path, mimetype='image/png')
74
- else:
75
- return jsonify({"error": "Image generation failed."}), 500
 
76
 
77
  if __name__ == '__main__':
78
  app.run(debug=True)
 
1
  from flask import Flask, request, jsonify, send_file
2
+ from flask_cors import CORS
 
3
  import asyncio
 
4
  from threading import RLock
5
  from gradio_client import Client
6
+ from all_models import models # Import your models
7
+ from externalmod import gr_Interface_load # Import the model loading function
8
+ import os
9
 
10
  app = Flask(__name__)
11
  CORS(app) # Enable CORS for all routes
12
 
 
 
 
13
  lock = RLock()
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
15
+ client = Client("Geek7/mdztxi2")
16
 
17
+ # Load models using gr_Interface_load
18
  def load_fn(models):
19
  global models_load
20
  models_load = {}
 
21
  for model in models:
22
  if model not in models_load.keys():
23
  try:
24
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
25
  except Exception as error:
26
  print(error)
27
+ m = gr.Interface(lambda: None, ['text'], ['image'])
28
  models_load.update({model: m})
29
 
30
+ # Load the models
31
  load_fn(models)
32
 
33
+ num_models = 6
34
+ MAX_SEED = 3999999999
35
+ inference_timeout = 600 # 10 minutes
36
+
37
+ # Asynchronous inference function
38
+ async def async_infer(model_str, prompt, seed=1, timeout=inference_timeout):
39
  try:
40
+ result = await asyncio.to_thread(client.predict, model_str=model_str, prompt=prompt, seed=seed, api_name="/predict")
 
 
 
 
 
 
 
41
  with lock:
42
  png_path = "image.png"
43
+ result.save(png_path) # Save the image to a file
44
  return png_path
45
+ except Exception as e:
46
+ print(f"Error during inference: {e}")
47
+ return None
48
 
49
+ # Define the endpoint for asynchronous inference
50
+ @app.route('/async_infer', methods=['POST'])
51
+ def async_infer_endpoint():
52
  data = request.get_json()
53
+ model_str = data['model_str']
54
+ prompt = data['prompt']
55
+ seed = data.get('seed', 1) # Default seed value is 1
 
 
 
 
56
 
57
+ # Make a prediction request
58
+ try:
59
+ result_path = asyncio.run(async_infer(model_str, prompt, seed))
60
+ if result_path:
61
+ return send_file(result_path, mimetype='image/png') # Send the generated image file
62
+ else:
63
+ return jsonify({"error": "Image generation failed."}), 500
64
+ except Exception as e:
65
+ return jsonify({"error": str(e)}), 500
66
 
67
  if __name__ == '__main__':
68
  app.run(debug=True)