sigyllly commited on
Commit
df3e218
·
verified ·
1 Parent(s): 5a00a52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -35
app.py CHANGED
@@ -14,8 +14,9 @@ app = Flask(__name__)
14
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
15
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
16
 
17
- # Global variable for caching results
18
- cache = {}
 
19
 
20
  # Function to process image and generate mask
21
  def process_image(image, prompt):
@@ -50,36 +51,7 @@ def get_masks(prompts, img, threshold):
50
 
51
  return masks
52
 
53
- # Function to extract image using positive and negative prompts
54
- def extract_image(pos_prompts, neg_prompts, img, threshold):
55
- cache_key = (pos_prompts, neg_prompts, threshold)
56
-
57
- if cache_key in cache:
58
- return cache[cache_key]
59
-
60
- positive_masks = get_masks(pos_prompts, img, 0.5)
61
- negative_masks = get_masks(neg_prompts, img, 0.5)
62
-
63
- pos_mask = np.any(np.stack(positive_masks), axis=0)
64
- neg_mask = np.any(np.stack(negative_masks), axis=0)
65
- final_mask = pos_mask & ~neg_mask
66
-
67
- final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
68
-
69
- # Convert final mask to base64
70
- buffered = io.BytesIO()
71
- final_mask.save(buffered, format="PNG")
72
- final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
73
-
74
- # Cache the result
75
- cache[cache_key] = {'final_mask_base64': final_mask_base64}
76
-
77
- return cache[cache_key]
78
-
79
- @app.route('/')
80
- def hello_world():
81
- return 'Hello, World!'
82
-
83
  @app.route('/api', methods=['POST'])
84
  def process_request():
85
  data = request.json
@@ -94,10 +66,22 @@ def process_request():
94
  neg_prompts = data.get('negative_prompts', '')
95
  threshold = float(data.get('threshold', 0.4))
96
 
97
- # Perform image segmentation
98
- result = extract_image(pos_prompts, neg_prompts, img, threshold)
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- return jsonify(result)
101
 
102
  # Keep the server alive using a periodic task
103
  def keep_alive():
 
14
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
15
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
16
 
17
+ @app.route('/')
18
+ def hello_world():
19
+ return 'Hello, World!'
20
 
21
  # Function to process image and generate mask
22
  def process_image(image, prompt):
 
51
 
52
  return masks
53
 
54
+ # Route for processing requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @app.route('/api', methods=['POST'])
56
  def process_request():
57
  data = request.json
 
66
  neg_prompts = data.get('negative_prompts', '')
67
  threshold = float(data.get('threshold', 0.4))
68
 
69
+ # Perform image segmentation without caching
70
+ positive_masks = get_masks(pos_prompts, img, 0.5)
71
+ negative_masks = get_masks(neg_prompts, img, 0.5)
72
+
73
+ pos_mask = np.any(np.stack(positive_masks), axis=0)
74
+ neg_mask = np.any(np.stack(negative_masks), axis=0)
75
+ final_mask = pos_mask & ~neg_mask
76
+
77
+ final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
78
+
79
+ # Convert final mask to base64
80
+ buffered = io.BytesIO()
81
+ final_mask.save(buffered, format="PNG")
82
+ final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
83
 
84
+ return jsonify({'final_mask_base64': final_mask_base64})
85
 
86
  # Keep the server alive using a periodic task
87
  def keep_alive():