sigyllly commited on
Commit
d432bdc
·
verified ·
1 Parent(s): fffe605

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -5
app.py CHANGED
@@ -5,6 +5,8 @@ import torch
5
  import numpy as np
6
  import io
7
  import base64
 
 
8
 
9
  app = Flask(__name__)
10
 
@@ -12,6 +14,9 @@ app = Flask(__name__)
12
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
13
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
14
 
 
 
 
15
  # Function to process image and generate mask
16
  def process_image(image, prompt):
17
  inputs = processor(
@@ -47,6 +52,11 @@ def get_masks(prompts, img, threshold):
47
 
48
  # Function to extract image using positive and negative prompts
49
  def extract_image(pos_prompts, neg_prompts, img, threshold):
 
 
 
 
 
50
  positive_masks = get_masks(pos_prompts, img, 0.5)
51
  negative_masks = get_masks(neg_prompts, img, 0.5)
52
 
@@ -55,15 +65,16 @@ def extract_image(pos_prompts, neg_prompts, img, threshold):
55
  final_mask = pos_mask & ~neg_mask
56
 
57
  final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
58
- output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
59
- output_image.paste(img, mask=final_mask)
60
 
61
  # Convert final mask to base64
62
  buffered = io.BytesIO()
63
  final_mask.save(buffered, format="PNG")
64
  final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
65
 
66
- return final_mask_base64
 
 
 
67
 
68
  @app.route('/')
69
  def hello_world():
@@ -84,10 +95,21 @@ def process_request():
84
  threshold = float(data.get('threshold', 0.4))
85
 
86
  # Perform image segmentation
87
- final_mask_base64 = extract_image(pos_prompts, neg_prompts, img, threshold)
88
 
89
- return jsonify({'final_mask_base64': final_mask_base64})
 
 
 
 
 
 
90
 
91
  if __name__ == '__main__':
92
  print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")
 
 
 
 
 
93
  app.run(host='0.0.0.0', port=7860, debug=True)
 
5
  import numpy as np
6
  import io
7
  import base64
8
+ import threading
9
+ import time
10
 
11
  app = Flask(__name__)
12
 
 
14
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
15
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
16
 
17
+ # Global variable for caching results
18
+ cache = {}
19
+
20
  # Function to process image and generate mask
21
  def process_image(image, prompt):
22
  inputs = processor(
 
52
 
53
  # Function to extract image using positive and negative prompts
54
  def extract_image(pos_prompts, neg_prompts, img, threshold):
55
+ cache_key = (pos_prompts, neg_prompts, threshold)
56
+
57
+ if cache_key in cache:
58
+ return cache[cache_key]
59
+
60
  positive_masks = get_masks(pos_prompts, img, 0.5)
61
  negative_masks = get_masks(neg_prompts, img, 0.5)
62
 
 
65
  final_mask = pos_mask & ~neg_mask
66
 
67
  final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
 
 
68
 
69
  # Convert final mask to base64
70
  buffered = io.BytesIO()
71
  final_mask.save(buffered, format="PNG")
72
  final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
73
 
74
+ # Cache the result
75
+ cache[cache_key] = {'final_mask_base64': final_mask_base64}
76
+
77
+ return cache[cache_key]
78
 
79
  @app.route('/')
80
  def hello_world():
 
95
  threshold = float(data.get('threshold', 0.4))
96
 
97
  # Perform image segmentation
98
+ result = extract_image(pos_prompts, neg_prompts, img, threshold)
99
 
100
+ return jsonify(result)
101
+
102
+ # Keep the server alive using a periodic task
103
+ def keep_alive():
104
+ while True:
105
+ time.sleep(300) # 5 minutes
106
+ requests.get('http://127.0.0.1:7860/') # Send a request to keep the server alive
107
 
108
  if __name__ == '__main__':
109
  print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")
110
+
111
+ # Start the keep-alive thread
112
+ keep_alive_thread = threading.Thread(target=keep_alive)
113
+ keep_alive_thread.start()
114
+
115
  app.run(host='0.0.0.0', port=7860, debug=True)