Geek7 commited on
Commit
ac39d6f
·
verified ·
1 Parent(s): 885a3c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -42
app.py CHANGED
@@ -1,44 +1,60 @@
1
- from flask import Flask, request, jsonify, send_file
2
- from gradio_client import Client
 
 
 
3
  import os
4
- from all_models import models # Import your models
5
-
6
- app = Flask(__name__)
7
-
8
- # Initialize the Gradio client
9
- client = Client("Geek7/mdztxi2")
10
-
11
- @app.route('/generate-image', methods=['POST'])
12
- def generate_image():
13
- data = request.get_json()
14
-
15
- # Check for required fields in the request data
16
- model_str = data.get('model_str')
17
- prompt = data.get('prompt')
18
- seed = data.get('seed')
19
-
20
- if not model_str or not prompt:
21
- return jsonify({"error": "Model string and prompt are required."}), 400
22
-
23
- # Make a prediction request
 
 
 
 
 
 
 
 
 
24
  try:
25
- # Making sure to capture the actual result returned
26
- result = client.predict(
27
- model_str=model_str,
28
- prompt=prompt,
29
- seed=seed,
30
- api_name="/predict"
31
- )
32
-
33
- # If the result is a file path, send it; if it's image bytes, handle accordingly
34
- if isinstance(result, str) and os.path.exists(result): # If the result is a path
35
- return send_file(result, mimetype='image/png') # Send the generated image file
36
- elif isinstance(result, bytes): # If the result is in bytes
37
- return send_file(io.BytesIO(result), mimetype='image/png') # Return image bytes as a file
38
- else:
39
- return jsonify({"error": "Image generation failed, result not valid."}), 500
40
- except Exception as e:
41
- return jsonify({"error": str(e)}), 500
42
-
43
- if __name__ == '__main__':
44
- app.run(debug=True)
 
 
 
 
 
1
+ import gradio as gr
2
+ from random import randint
3
+ from all_models import models
4
+ from externalmod import gr_Interface_load
5
+ import asyncio
6
  import os
7
+ from threading import RLock
8
+
9
+ lock = RLock()
10
+ HF_TOKEN = os.environ.get("HF_TOKEN")
11
+
12
+ def load_fn(models):
13
+ global models_load
14
+ models_load = {}
15
+
16
+ for model in models:
17
+ if model not in models_load.keys():
18
+ try:
19
+ m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
20
+ except Exception as error:
21
+ print(error)
22
+ m = gr.Interface(lambda: None, ['text'], ['image'])
23
+ models_load.update({model: m})
24
+
25
+ load_fn(models)
26
+
27
+ num_models = 6
28
+ MAX_SEED = 3999999999
29
+ default_models = models[:num_models]
30
+ inference_timeout = 600
31
+
32
+ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
33
+ kwargs = {"seed": seed}
34
+ task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
35
+ await asyncio.sleep(0)
36
  try:
37
+ result = await asyncio.wait_for(task, timeout=timeout)
38
+ except (Exception, asyncio.TimeoutError) as e:
39
+ print(e)
40
+ print(f"Task timed out: {model_str}")
41
+ if not task.done():
42
+ task.cancel()
43
+ result = None
44
+ if task.done() and result is not None:
45
+ with lock:
46
+ png_path = "image.png"
47
+ result.save(png_path)
48
+ return png_path
49
+ return None
50
+
51
+ # Expose Gradio API
52
+ def generate_api(model_str, prompt, seed=1):
53
+ result = asyncio.run(infer(model_str, prompt, seed))
54
+ if result:
55
+ return result # Path to generated image
56
+ return None
57
+
58
+ # Launch Gradio API without frontend
59
+ iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
60
+ iface.launch(show_api=True, share=True)