Geek7 committed on
Commit
74fabd0
1 Parent(s): 3806071

Update myapp.py

Browse files
Files changed (1) hide show
  1. myapp.py +24 -60
myapp.py CHANGED
@@ -1,61 +1,28 @@
1
  from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
3
- import asyncio
4
- import tempfile
5
  import os
6
- from threading import RLock
7
  from huggingface_hub import InferenceClient
8
- from all_models import models # Importing models from all_models
9
  from io import BytesIO # For converting image to bytes
10
 
11
# Application object and cross-cutting configuration.
myapp = Flask(__name__)
CORS(myapp)  # Enable CORS for all routes

# Used in infer() around image serialization — presumably to serialize
# PIL/tempfile work across worker threads; confirm it is actually needed.
lock = RLock()
HF_TOKEN = os.environ.get("HF_TOKEN")  # Hugging Face token; None if unset

inference_timeout = 600  # Set timeout for inference (seconds)
18
-
19
@myapp.route('/')
def home():
    """Root endpoint: return a plain greeting (doubles as a health check)."""
    greeting = "Welcome to the Image Background Remover!"
    return greeting
22
-
23
# Resolve a requested model name against the "models" registry.
def get_model_from_name(model_name):
    """Return *model_name* unchanged if it is a known model, else None."""
    if model_name in models:
        return model_name
    return None
26
-
27
# Asynchronous wrapper around the blocking text_to_image call.
async def infer(client, prompt, seed=1, timeout=600, model="prompthero/openjourney-v4"):
    """Run ``client.text_to_image`` in a worker thread with a timeout.

    Args:
        client: object exposing ``text_to_image(prompt=..., seed=..., model=...)``.
        prompt: text prompt for the image model.
        seed: generation seed (default 1).
        timeout: seconds to wait before giving up (default 600, matching the
            module-level ``inference_timeout``).
        model: model identifier passed through to the client.

    Returns:
        Path to a temporary ``.png`` file on success, or ``None`` on timeout
        or any inference error. The caller owns (and should delete) the file.
    """
    try:
        # wait_for cancels the worker task itself on timeout — no manual
        # create_task / task.cancel() bookkeeping needed.
        result = await asyncio.wait_for(
            asyncio.to_thread(client.text_to_image, prompt=prompt, seed=seed, model=model),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        # Distinguish a genuine timeout from other failures (the original
        # printed "timed out" for every exception).
        print(f"Task timed out for model: {model}")
        return None
    except Exception as e:
        print(f"Inference failed for model {model}: {e}")
        return None

    if result is None:
        return None

    # Serialize the image to bytes; this touches only local state, so the
    # module-level lock the original held here is unnecessary.
    image_bytes = BytesIO()
    result.save(image_bytes, format='PNG')

    # delete=False so the path stays valid for the caller (e.g. send_file);
    # writing through the open handle avoids the original's fd leak from
    # re-opening the file by name without closing the NamedTemporaryFile.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp.write(image_bytes.getvalue())
        return tmp.name
56
 
57
  # Flask route for the API endpoint
58
- @myapp.route('/generate_api', methods=['POST'])
59
  def generate_api():
60
  data = request.get_json()
61
 
@@ -67,19 +34,16 @@ def generate_api():
67
  if not prompt:
68
  return jsonify({"error": "Prompt is required"}), 400
69
 
70
- # Get the model from all_models
71
- model = get_model_from_name(model_name)
72
- if not model:
73
- return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400
74
-
75
  try:
76
- # Create a generic InferenceClient for the model
77
- client = InferenceClient(token=HF_TOKEN)
78
-
79
- # Call the async inference function
80
- result_path = asyncio.run(infer(client, prompt, seed, model=model))
81
- if result_path:
82
- return send_file(result_path, mimetype='image/png') # Send back the generated image file
 
 
83
  else:
84
  return jsonify({"error": "Failed to generate image"}), 500
85
  except Exception as e:
@@ -88,4 +52,4 @@ def generate_api():
88
 
89
# Add this block to make sure your app runs when called
if __name__ == "__main__":
    # Bind on all interfaces; 7860 appears to be the hosting platform's
    # expected port — confirm against deployment config.
    myapp.run(host='0.0.0.0', port=7860)  # Run directly if needed for testing
 
1
  from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
 
 
3
  import os
 
4
  from huggingface_hub import InferenceClient
 
5
  from io import BytesIO # For converting image to bytes
6
 
7
# Initialize the Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Initialize the InferenceClient with your Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN")  # Ensure to set your Hugging Face token in the environment
# Single shared client reused by all requests (created once at import time).
client = InferenceClient(token=HF_TOKEN)
14
 
15
# Function to generate an image from a prompt
def generate_image(prompt, seed=1, model="prompthero/openjourney-v4"):
    """Generate an image for *prompt* via the shared InferenceClient.

    Returns the image object from ``client.text_to_image`` on success,
    or ``None`` if inference raises — failures are logged, not propagated.
    """
    try:
        return client.text_to_image(prompt=prompt, seed=seed, model=model)
    except Exception as e:
        # Deliberate best-effort: log and signal failure with None so the
        # route can turn it into an HTTP 500 instead of crashing.
        print(f"Error generating image: {str(e)}")
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Flask route for the API endpoint
25
+ @app.route('/generate_image', methods=['POST'])
26
  def generate_api():
27
  data = request.get_json()
28
 
 
34
  if not prompt:
35
  return jsonify({"error": "Prompt is required"}), 400
36
 
 
 
 
 
 
37
  try:
38
+ # Call the generate_image function
39
+ image = generate_image(prompt, seed, model_name)
40
+
41
+ if image:
42
+ # Save the image to a BytesIO object to send as response
43
+ image_bytes = BytesIO()
44
+ image.save(image_bytes, format='PNG')
45
+ image_bytes.seek(0) # Go to the start of the byte stream
46
+ return send_file(image_bytes, mimetype='image/png', as_attachment=True, download_name='generated_image.png')
47
  else:
48
  return jsonify({"error": "Failed to generate image"}), 500
49
  except Exception as e:
 
52
 
53
# Add this block to make sure your app runs when called
if __name__ == "__main__":
    # Bind on all interfaces; 7860 appears to be the hosting platform's
    # expected port — confirm against deployment config.
    app.run(host='0.0.0.0', port=7860)  # Run directly if needed for testing