Update app.py
Browse files
app.py
CHANGED
@@ -45,10 +45,6 @@ def get_masks(prompts, img, threshold):
|
|
45 |
|
46 |
return masks
|
47 |
|
48 |
-
@app.route('/')
|
49 |
-
def hello_world():
|
50 |
-
return 'Hello, World!'
|
51 |
-
|
52 |
# Function to extract image using positive and negative prompts
|
53 |
def extract_image(pos_prompts, neg_prompts, img, threshold):
|
54 |
positive_masks = get_masks(pos_prompts, img, 0.5)
|
@@ -62,7 +58,16 @@ def extract_image(pos_prompts, neg_prompts, img, threshold):
|
|
62 |
output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
63 |
output_image.paste(img, mask=final_mask)
|
64 |
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
@app.route('/api', methods=['POST'])
|
68 |
def process_request():
|
@@ -79,16 +84,10 @@ def process_request():
|
|
79 |
threshold = float(data.get('threshold', 0.4))
|
80 |
|
81 |
# Perform image segmentation
|
82 |
-
|
83 |
|
84 |
-
|
85 |
-
buffered = io.BytesIO()
|
86 |
-
output_image.save(buffered, format="PNG")
|
87 |
-
result_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
88 |
-
|
89 |
-
return jsonify({'result_image_base64': result_image_base64})
|
90 |
|
91 |
if __name__ == '__main__':
|
92 |
print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")
|
93 |
app.run(host='0.0.0.0', port=7860, debug=True)
|
94 |
-
|
|
|
45 |
|
46 |
return masks
|
47 |
|
|
|
|
|
|
|
|
|
48 |
# Function to extract image using positive and negative prompts
|
49 |
def extract_image(pos_prompts, neg_prompts, img, threshold):
|
50 |
positive_masks = get_masks(pos_prompts, img, 0.5)
|
|
|
58 |
output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
59 |
output_image.paste(img, mask=final_mask)
|
60 |
|
61 |
+
# Convert final mask to base64
|
62 |
+
buffered = io.BytesIO()
|
63 |
+
final_mask.save(buffered, format="PNG")
|
64 |
+
final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
65 |
+
|
66 |
+
return final_mask_base64
|
67 |
+
|
68 |
+
@app.route('/')
|
69 |
+
def hello_world():
|
70 |
+
return 'Hello, World!'
|
71 |
|
72 |
@app.route('/api', methods=['POST'])
|
73 |
def process_request():
|
|
|
84 |
threshold = float(data.get('threshold', 0.4))
|
85 |
|
86 |
# Perform image segmentation
|
87 |
+
final_mask_base64 = extract_image(pos_prompts, neg_prompts, img, threshold)
|
88 |
|
89 |
+
return jsonify({'final_mask_base64': final_mask_base64})
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
if __name__ == '__main__':
|
92 |
print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")
|
93 |
app.run(host='0.0.0.0', port=7860, debug=True)
|
|