Geek7 commited on
Commit
5ee6e7d
·
verified ·
1 Parent(s): 4a761f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  from huggingface_hub import InferenceClient
5
  from io import BytesIO
6
  from PIL import Image
 
7
 
8
  # Initialize the Flask app
9
  app = Flask(__name__)
@@ -13,14 +14,20 @@ CORS(app) # Enable CORS for all routes
13
  HF_TOKEN = os.environ.get("HF_TOKEN") # Ensure to set your Hugging Face token in the environment
14
  client = InferenceClient(token=HF_TOKEN)
15
 
 
 
 
16
  @app.route('/')
17
  def home():
18
  return "Welcome to the Image Background Remover!"
19
 
20
-
21
  # Function to generate an image from a text prompt
22
  def generate_image(prompt, negative_prompt=None, height=512, width=512, model="stabilityai/stable-diffusion-2-1", num_inference_steps=50, guidance_scale=7.5, seed=None):
23
  try:
 
 
 
 
24
  # Generate the image using Hugging Face's inference API with additional parameters
25
  image = client.text_to_image(
26
  prompt=prompt,
@@ -40,6 +47,9 @@ def generate_image(prompt, negative_prompt=None, height=512, width=512, model="s
40
  # Flask route for the API endpoint to generate an image
41
  @app.route('/generate_image', methods=['POST'])
42
  def generate_api():
 
 
 
43
  data = request.get_json()
44
 
45
  # Extract required fields from the request
@@ -59,6 +69,10 @@ def generate_api():
59
  # Call the generate_image function with the provided parameters
60
  image = generate_image(prompt, negative_prompt, height, width, model_name, num_inference_steps, guidance_scale, seed)
61
 
 
 
 
 
62
  if image:
63
  # Save the image to a BytesIO object
64
  img_byte_arr = BytesIO()
@@ -78,6 +92,13 @@ def generate_api():
78
  print(f"Error in generate_api: {str(e)}") # Log the error
79
  return jsonify({"error": str(e)}), 500
80
 
 
 
 
 
 
 
 
81
  # Add this block to make sure your app runs when called
82
  if __name__ == "__main__":
83
  app.run(host='0.0.0.0', port=7860) # Run directly if needed for testing
 
4
  from huggingface_hub import InferenceClient
5
  from io import BytesIO
6
  from PIL import Image
7
+ import threading
8
 
9
  # Initialize the Flask app
10
  app = Flask(__name__)
 
14
  HF_TOKEN = os.environ.get("HF_TOKEN") # Ensure to set your Hugging Face token in the environment
15
  client = InferenceClient(token=HF_TOKEN)
16
 
17
+ # Global variable to manage generation rejection state
18
+ generation_rejected = threading.Event()
19
+
20
  @app.route('/')
21
  def home():
22
  return "Welcome to the Image Background Remover!"
23
 
 
24
  # Function to generate an image from a text prompt
25
  def generate_image(prompt, negative_prompt=None, height=512, width=512, model="stabilityai/stable-diffusion-2-1", num_inference_steps=50, guidance_scale=7.5, seed=None):
26
  try:
27
+ # Check if generation has been marked for rejection
28
+ if generation_rejected.is_set():
29
+ return "rejected"
30
+
31
  # Generate the image using Hugging Face's inference API with additional parameters
32
  image = client.text_to_image(
33
  prompt=prompt,
 
47
  # Flask route for the API endpoint to generate an image
48
  @app.route('/generate_image', methods=['POST'])
49
  def generate_api():
50
+ global generation_rejected
51
+ generation_rejected.clear() # Reset rejection state at the start of a new request
52
+
53
  data = request.get_json()
54
 
55
  # Extract required fields from the request
 
69
  # Call the generate_image function with the provided parameters
70
  image = generate_image(prompt, negative_prompt, height, width, model_name, num_inference_steps, guidance_scale, seed)
71
 
72
+ # Check if the response was rejected
73
+ if image == "rejected":
74
+ return jsonify({"error": "Image generation was rejected."}), 400
75
+
76
  if image:
77
  # Save the image to a BytesIO object
78
  img_byte_arr = BytesIO()
 
92
  print(f"Error in generate_api: {str(e)}") # Log the error
93
  return jsonify({"error": str(e)}), 500
94
 
95
+ # Add a new endpoint to reject the image generation
96
+ @app.route('/reject', methods=['POST'])
97
+ def reject_generation():
98
+ global generation_rejected
99
+ generation_rejected.set() # Set the flag to indicate the request should be rejected
100
+ return jsonify({"message": "Image generation has been marked for rejection."})
101
+
102
  # Add this block to make sure your app runs when called
103
  if __name__ == "__main__":
104
  app.run(host='0.0.0.0', port=7860) # Run directly if needed for testing