Geek7 commited on
Commit
55b35d2
·
verified ·
1 Parent(s): 6657ebf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -31
app.py CHANGED
@@ -1,27 +1,39 @@
 
 
 
1
  import gradio as gr
2
- from random import randint
3
- from all_models import models
4
- from externalmod import gr_Interface_load
5
- import asyncio
6
  import os
 
7
  from threading import RLock
 
 
8
 
9
- lock = RLock()
 
 
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
 
 
11
 
 
12
  def load_fn(models):
13
  global models_load
14
  models_load = {}
15
-
16
  for model in models:
17
  if model not in models_load.keys():
18
  try:
 
19
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
20
  except Exception as error:
21
  print(error)
22
- m = gr.Interface(lambda: None, ['text'], ['image'])
23
  models_load.update({model: m})
24
 
 
25
  load_fn(models)
26
 
27
  num_models = 6
@@ -30,31 +42,33 @@ default_models = models[:num_models]
30
  inference_timeout = 600
31
 
32
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
33
- kwargs = {"seed": seed}
34
- task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
35
- await asyncio.sleep(0)
36
  try:
37
- result = await asyncio.wait_for(task, timeout=timeout)
38
- except (Exception, asyncio.TimeoutError) as e:
39
- print(e)
40
- print(f"Task timed out: {model_str}")
41
- if not task.done():
42
- task.cancel()
43
- result = None
44
- if task.done() and result is not None:
45
  with lock:
46
  png_path = "image.png"
47
- result.save(png_path)
48
  return png_path
49
- return None
50
-
51
- # Expose Gradio API
52
- def generate_api(model_str, prompt, seed=1):
53
- result = asyncio.run(infer(model_str, prompt, seed,api_name="/predict" ))
54
- if result:
55
- return result # Path to generated image
56
- return None
57
-
58
- # Launch Gradio API without frontend
59
- iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
60
- iface.launch(show_api=True, share=True)
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_file
2
+ from flask_cors import CORS # Import CORS
3
+ from gradio_client import Client
4
  import gradio as gr
 
 
 
 
5
  import os
6
+ import asyncio
7
  from threading import RLock
8
+ from all_models import models # Import your models
9
+ from externalmod import gr_Interface_load # Import the model loading function
10
 
11
+ app = Flask(__name__)
12
+ CORS(app) # Enable CORS for all routes
13
+
14
+ # Environment variable for Hugging Face token
15
  HF_TOKEN = os.environ.get("HF_TOKEN")
16
+ lock = RLock()
17
+
18
+ # Initialize the Gradio client
19
+ client = Client("Geek7/mdztxi2")
20
 
21
+ # Load models using gr_Interface_load
22
  def load_fn(models):
23
  global models_load
24
  models_load = {}
25
+
26
  for model in models:
27
  if model not in models_load.keys():
28
  try:
29
+ # Use your custom model loading function
30
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
31
  except Exception as error:
32
  print(error)
33
+ m = gr.Interface(lambda: None, ['text'], ['image']) # Fallback to a dummy interface
34
  models_load.update({model: m})
35
 
36
+ # Load the models
37
  load_fn(models)
38
 
39
  num_models = 6
 
42
  inference_timeout = 600
43
 
44
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
 
 
 
45
  try:
46
+ # Use the gradio_client to make the API call
47
+ result = await asyncio.to_thread(client.predict, model_str=model_str, prompt=prompt, seed=seed, api_name="/predict")
 
 
 
 
 
 
48
  with lock:
49
  png_path = "image.png"
50
+ result.save(png_path) # Save the image to a file
51
  return png_path
52
+ except Exception as e:
53
+ print(f"Error during inference: {e}")
54
+ return None
55
+
56
+ @app.route('/generate-image', methods=['POST'])
57
+ def generate_image():
58
+ data = request.get_json()
59
+ model_str = data['model_str']
60
+ prompt = data['prompt']
61
+ seed = data['seed']
62
+
63
+ # Make a prediction request
64
+ try:
65
+ result_path = asyncio.run(infer(model_str, prompt, seed))
66
+ if result_path:
67
+ return send_file(result_path, mimetype='image/png') # Send the generated image file
68
+ else:
69
+ return jsonify({"error": "Image generation failed."}), 500
70
+ except Exception as e:
71
+ return jsonify({"error": str(e)}), 500
72
+
73
+ if __name__ == '__main__':
74
+ app.run(debug=True)