Geek7 commited on
Commit
ba47c7a
·
verified ·
1 Parent(s): c680ed4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -1
app.py CHANGED
@@ -1,3 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  @app.route('/predict', methods=['POST'])
2
  def predict():
3
  try:
@@ -41,4 +104,9 @@ def predict():
41
 
42
  except Exception as e:
43
  print(f"Error in /predict: {str(e)}")
44
- return jsonify({"error": "An error occurred during processing."}), 500
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from random import randint
3
+ from all_models import models
4
+ from externalmod import gr_Interface_load
5
+ import asyncio
6
+ import os
7
+ from threading import RLock
8
+ from gradio_client import Client
9
+ from flask import Flask, request, jsonify, send_file
10
+ from flask_cors import CORS
11
+
12
# Flask application hosting the /predict endpoint defined later in this file.
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Serialises writes to the shared "image.png" output file across worker threads.
lock = RLock()
# Hugging Face API token; None when the env var is unset — TODO confirm the
# downstream gr_Interface_load / model calls tolerate a missing token.
HF_TOKEN = os.environ.get("HF_TOKEN")
17
+
18
def load_fn(models):
    """Populate the global ``models_load`` dict with one Gradio interface per model.

    For each model id in *models*, try to load the hosted model via
    ``gr_Interface_load``; on failure, install a no-op stub interface so the
    rest of the app keeps working.

    Parameters
    ----------
    models : list[str]
        Model identifiers, loaded as ``models/<id>`` from the Hub.
    """
    global models_load
    models_load = {}

    for model in models:
        # Dict was just reset, but keep the guard so duplicate ids load once.
        if model not in models_load:
            try:
                m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
            except Exception as error:
                print(error)
                # BUG FIX: the stub's fn must accept the prompt argument declared
                # by the ['text'] input (and the extra seed/token kwargs that
                # infer() passes); a zero-arg lambda raised TypeError when used.
                m = gr.Interface(lambda prompt, **kwargs: None, ['text'], ['image'])
            models_load[model] = m
30
+
31
# Build the model registry at import time (side effect: defines models_load).
load_fn(models)

num_models = 6  # how many models are exposed by default
MAX_SEED = 3999999999  # upper bound for the random seed input
default_models = models[:num_models]  # first N entries of the model list
inference_timeout = 600  # seconds allowed per inference call
37
+
38
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    """Run one model's inference in a worker thread and save the result image.

    Parameters
    ----------
    model_str : str
        Key into the global ``models_load`` registry.
    prompt : str
        Text prompt forwarded to the model.
    seed : int
        Random seed forwarded to the model.
    timeout : float
        Seconds to wait before cancelling the inference task.

    Returns
    -------
    str | None
        Path of the saved PNG on success; ``None`` on failure, timeout,
        or when the model produced no image.
    """
    kwargs = {"seed": seed}
    task = asyncio.create_task(
        asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN)
    )
    await asyncio.sleep(0)  # yield once so the task actually starts
    result = None
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except Exception as e:
        # asyncio.TimeoutError is already an Exception subclass, so a single
        # handler covers both timeouts and model errors (the old second clause
        # was redundant, and the old message claimed a timeout for any error).
        print(e)
        print(f"Inference failed or timed out: {model_str}")
        if not task.done():
            task.cancel()
        result = None
    if task.done() and result is not None:
        # Serialise writes to the shared output path across concurrent requests.
        with lock:
            png_path = "image.png"
            result.save(png_path)
            return png_path
    return None
56
+
57
+ # Expose Gradio API
58
def generate_api(model_str, prompt, seed=1):
    """Synchronous wrapper around :func:`infer` for the Gradio endpoint.

    Returns the path of the generated image, or ``None`` when inference
    failed or produced nothing.
    """
    image_path = asyncio.run(infer(model_str, prompt, seed))
    return image_path if image_path else None
63
+
64
  @app.route('/predict', methods=['POST'])
65
  def predict():
66
  try:
 
104
 
105
  except Exception as e:
106
  print(f"Error in /predict: {str(e)}")
107
+ return jsonify({"error": "An error occurred during processing."}), 500
108
+
109
+
110
# Launch Gradio API without frontend
# NOTE(review): Interface.launch() blocks the main thread, so the Flask
# routes defined above are never served unless the Flask app is run
# separately — confirm this wiring is intentional.
iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
iface.launch(show_api=True, share=True)