Geek7 committed (verified)
Commit 1dec5c9 · 1 parent: 0a71ede

Update app.py

Files changed (1):
  app.py  +5 -12
app.py CHANGED

@@ -36,7 +36,7 @@ MAX_SEED = 3999999999
 default_models = models[:num_models]
 inference_timeout = 600
 
-# Inference function to generate image
+# Inference function with generate_api embedded
 async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
     kwargs = {"seed": seed}
     task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
@@ -56,14 +56,7 @@ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
         return png_path
     return None
 
-# Generate API function that calls the async infer function
-def generate_api(model_str, prompt, seed=1):
-    result = asyncio.run(infer(model_str, prompt, seed))
-    if result:
-        return result  # Path to the generated image
-    return None
-
-# Flask route to handle predictions
+# Flask API to call the generate_api function
 @app.route('/predict', methods=['POST'])
 def predict():
     data = request.get_json()
@@ -76,9 +69,9 @@ def predict():
     prompt = data['prompt']
     seed = data.get('seed', 1)
 
-    # Generate the image using the model
+    # Make the asynchronous call to the infer function within the Flask route
     try:
-        image_path = generate_api(model_str, prompt, seed)
+        image_path = asyncio.run(infer(model_str, prompt, seed))  # Directly call infer function here
         if image_path:
             return send_file(image_path, mimetype='image/png')
         else:
@@ -91,5 +84,5 @@ if __name__ == '__main__':
     app.run(debug=True)
 
     # You can optionally launch the Gradio interface in parallel
-    iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
+    iface = gr.Interface(fn=infer, inputs=["text", "text", "number"], outputs="file")
     iface.launch(show_api=True, share=True)
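
For reference, a minimal client sketch for the /predict route as it looks after this commit. Because asyncio.run starts a fresh event loop per request, calling infer directly from the synchronous Flask view behaves the same as the removed generate_api wrapper did. The "prompt" and "seed" keys below match the hunks above; the "model" key and its value are assumptions, since the line that extracts model_str from the request body falls outside the diff context. app.run(debug=True) serves on Flask's default port 5000.

# Hypothetical client for the updated /predict route (sketch, not part of the commit).
import requests

payload = {
    "model": "some-model-id",                  # assumed key and value; not shown in the diff
    "prompt": "a watercolor fox in a forest",  # read via data['prompt']
    "seed": 1,                                 # read via data.get('seed', 1)
}

# inference_timeout is 600 s server-side, so allow slightly more on the client.
resp = requests.post("http://127.0.0.1:5000/predict", json=payload, timeout=650)
if resp.ok:
    with open("output.png", "wb") as f:
        f.write(resp.content)  # the route returns the PNG via send_file(..., mimetype='image/png')
else:
    print(resp.status_code, resp.text)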
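
One caveat on the last hunk: infer is an async coroutine with signature (model_str, prompt, seed=1, timeout=inference_timeout), and recent Gradio versions can call async functions directly, mapping the three inputs to the first three parameters. If the installed Gradio version cannot, a thin synchronous wrapper (a sketch, not part of this commit) would keep the interface working without reintroducing generate_api into the Flask path:

# Sketch only: a sync wrapper for Gradio versions that do not accept async fns.
import asyncio

def infer_sync(model_str, prompt, seed=1):
    # Delegates to the async infer() defined earlier in app.py.
    return asyncio.run(infer(model_str, prompt, seed))

# iface = gr.Interface(fn=infer_sync, inputs=["text", "text", "number"], outputs="file")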