Geek7 committed on
Commit
0a71ede
·
verified ·
1 Parent(s): 7044e4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -28
app.py CHANGED
@@ -1,55 +1,95 @@
 
 
1
  from all_models import models
2
- from flask import Flask, request, jsonify, send_file
3
- from flask_cors import CORS
4
- from gradio_client import Client
5
  from externalmod import gr_Interface_load
 
6
  import os
 
 
 
 
 
 
7
 
 
8
  app = Flask(__name__)
9
  CORS(app) # Enable CORS for all routes
10
 
11
- HF_TOKEN = os.environ.get("HF_TOKEN") # Load the Hugging Face token if needed
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Function to load the model using gr_Interface_load
14
- def load_model(model_name):
15
- try:
16
- model_interface = gr_Interface_load(f'models/{model_name}', hf_token=HF_TOKEN)
17
- return model_interface
18
- except Exception as error:
19
- print(f"Error loading model: {error}")
20
- return None
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
 
 
 
 
 
 
23
 
 
24
  @app.route('/predict', methods=['POST'])
25
  def predict():
26
- # Get the JSON data from the request
27
  data = request.get_json()
28
 
29
  # Validate required fields
30
- if not data or 'prompt' not in data:
31
  return jsonify({"error": "Missing required fields"}), 400
32
 
 
33
  prompt = data['prompt']
34
  seed = data.get('seed', 1)
35
 
36
- # Make a prediction request using the loaded model
37
  try:
38
- if model_interface:
39
- # Send the request to the model interface and retrieve the result
40
- result = model_interface.fn(prompt=prompt, seed=seed, token=HF_TOKEN) # Assuming the function returns the result
41
-
42
- # Save the result to a file (if the model returns an image)
43
- result_path = "generated_image.png"
44
- result.save(result_path) # Assuming result has a save method
45
-
46
- # Send back the generated image file
47
- return send_file(result_path, mimetype='image/png')
48
  else:
49
- return jsonify({"error": "Model interface not loaded."}), 500
50
-
51
  except Exception as e:
52
  return jsonify({"error": str(e)}), 500
53
 
54
  if __name__ == '__main__':
55
- app.run(debug=True)
 
 
 
 
 
 
1
import asyncio
import os
import tempfile
from random import randint
from threading import RLock

import gradio as gr
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS

from all_models import models
from externalmod import gr_Interface_load
10
+
11
+ lock = RLock()
12
+ HF_TOKEN = os.environ.get("HF_TOKEN")
13
 
14
+ # Initialize Flask app
15
  app = Flask(__name__)
16
  CORS(app) # Enable CORS for all routes
17
 
18
+ # Load models using gr_Interface_load
19
+ def load_fn(models):
20
+ global models_load
21
+ models_load = {}
22
+
23
+ for model in models:
24
+ if model not in models_load.keys():
25
+ try:
26
+ m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
27
+ except Exception as error:
28
+ print(error)
29
+ m = gr.Interface(lambda: None, ['text'], ['image'])
30
+ models_load.update({model: m})
31
 
32
+ load_fn(models)
33
+
34
+ num_models = 6
35
+ MAX_SEED = 3999999999
36
+ default_models = models[:num_models]
37
+ inference_timeout = 600
 
 
38
 
39
+ # Inference function to generate image
40
+ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
41
+ kwargs = {"seed": seed}
42
+ task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
43
+ await asyncio.sleep(0)
44
+ try:
45
+ result = await asyncio.wait_for(task, timeout=timeout)
46
+ except (Exception, asyncio.TimeoutError) as e:
47
+ print(e)
48
+ print(f"Task timed out: {model_str}")
49
+ if not task.done():
50
+ task.cancel()
51
+ result = None
52
+ if task.done() and result is not None:
53
+ with lock:
54
+ png_path = "generated_image.png"
55
+ result.save(png_path)
56
+ return png_path
57
+ return None
58
 
59
+ # Generate API function that calls the async infer function
60
+ def generate_api(model_str, prompt, seed=1):
61
+ result = asyncio.run(infer(model_str, prompt, seed))
62
+ if result:
63
+ return result # Path to the generated image
64
+ return None
65
 
66
+ # Flask route to handle predictions
67
  @app.route('/predict', methods=['POST'])
68
  def predict():
 
69
  data = request.get_json()
70
 
71
  # Validate required fields
72
+ if not data or 'prompt' not in data or 'model_str' not in data:
73
  return jsonify({"error": "Missing required fields"}), 400
74
 
75
+ model_str = data['model_str']
76
  prompt = data['prompt']
77
  seed = data.get('seed', 1)
78
 
79
+ # Generate the image using the model
80
  try:
81
+ image_path = generate_api(model_str, prompt, seed)
82
+ if image_path:
83
+ return send_file(image_path, mimetype='image/png')
 
 
 
 
 
 
 
84
  else:
85
+ return jsonify({"error": "Failed to generate image"}), 500
 
86
  except Exception as e:
87
  return jsonify({"error": str(e)}), 500
88
 
89
  if __name__ == '__main__':
90
+ # Run Flask app
91
+ app.run(debug=True)
92
+
93
+ # You can optionally launch the Gradio interface in parallel
94
+ iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
95
+ iface.launch(show_api=True, share=True)