Geek7 commited on
Commit
7151510
·
verified ·
1 Parent(s): 4495eec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -25
app.py CHANGED
@@ -1,17 +1,23 @@
 
 
1
  import gradio as gr
2
- from random import randint
3
- from all_models import models
4
- from externalmod import gr_Interface_load
5
  import asyncio
6
  import os
7
  from threading import RLock
8
  from gradio_client import Client
 
 
9
 
 
 
 
 
10
  client = Client("Geek7/mdztxi2")
11
 
12
  lock = RLock()
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
 
 
15
  def load_fn(models):
16
  global models_load
17
  models_load = {}
@@ -22,17 +28,13 @@ def load_fn(models):
22
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
23
  except Exception as error:
24
  print(error)
25
- m = gr.Interface(lambda: None, ['text'], ['image'])
26
  models_load.update({model: m})
27
 
28
  load_fn(models)
29
 
30
- num_models = 6
31
- MAX_SEED = 3999999999
32
- default_models = models[:num_models]
33
- inference_timeout = 600
34
-
35
- async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
36
  kwargs = {"seed": seed}
37
  task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
38
  await asyncio.sleep(0)
@@ -51,19 +53,26 @@ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
51
  return png_path
52
  return None
53
 
54
- # Expose Gradio API
55
- def generate_api(model_str, prompt, seed=1):
56
- result = asyncio.run(infer(model_str, prompt, seed))
57
- # result = client.predict(
58
- # model_str=model_str,
59
- # prompt=prompt,
60
- # seed=seed,
61
- # api_name="/predict"
62
- #)
63
- if result:
64
- return result # Path to generated image
65
- return None
 
 
 
 
 
 
 
 
66
 
67
- # Launch Gradio API without frontend
68
- iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file", api_name="/predict")
69
- iface.launch(show_api=True, share=True)
 
1
+ from flask import Flask, request, jsonify, send_file
2
+ from flask_cors import CORS # For enabling CORS
3
  import gradio as gr
 
 
 
4
  import asyncio
5
  import os
6
  from threading import RLock
7
  from gradio_client import Client
8
+ from all_models import models # Your model import
9
+ from externalmod import gr_Interface_load # Your custom model loader
10
 
11
+ app = Flask(__name__)
12
+ CORS(app) # Enable CORS for all routes
13
+
14
+ # Gradio Client Initialization
15
  client = Client("Geek7/mdztxi2")
16
 
17
  lock = RLock()
18
  HF_TOKEN = os.environ.get("HF_TOKEN")
19
 
20
+ # Model Loading Function
21
  def load_fn(models):
22
  global models_load
23
  models_load = {}
 
28
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
29
  except Exception as error:
30
  print(error)
31
+ m = gr.Interface(lambda: None, ['text'], ['image']) # Fallback
32
  models_load.update({model: m})
33
 
34
  load_fn(models)
35
 
36
+ # Async inference function to call Gradio model prediction
37
+ async def infer(model_str, prompt, seed=1, timeout=600):
 
 
 
 
38
  kwargs = {"seed": seed}
39
  task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
40
  await asyncio.sleep(0)
 
53
  return png_path
54
  return None
55
 
56
+ # API endpoint for generating an image and sending it as a file
57
+ @app.route('/generate-image', methods=['POST'])
58
+ def generate_image():
59
+ data = request.get_json()
60
+ model_str = data.get('model_str')
61
+ prompt = data.get('prompt')
62
+ seed = data.get('seed', 1)
63
+
64
+ # Validate input
65
+ if not model_str or not prompt:
66
+ return jsonify({"error": "Model string and prompt are required."}), 400
67
+
68
+ # Generate image using the async inference function
69
+ result_path = asyncio.run(infer(model_str, prompt, seed))
70
+
71
+ if result_path:
72
+ # Return the image file using send_file
73
+ return send_file(result_path, mimetype='image/png')
74
+ else:
75
+ return jsonify({"error": "Image generation failed."}), 500
76
 
77
+ if __name__ == '__main__':
78
+ app.run(debug=True)