sigyllly commited on
Commit
63596f0
·
verified ·
1 Parent(s): 90674f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -139
app.py CHANGED
@@ -1,146 +1,104 @@
1
- from flask import Flask, request, jsonify, render_template_string
2
- import subprocess
 
 
 
 
 
 
 
3
 
4
  app = Flask(__name__)
5
 
6
- # Route to the homepage with embedded HTML
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  @app.route('/')
8
  def index():
9
- html = """
10
- <!DOCTYPE html>
11
- <html lang="en">
12
- <head>
13
- <meta charset="UTF-8">
14
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
15
- <title>Administrator Terminal</title>
16
- <style>
17
- body {
18
- font-family: Arial, sans-serif;
19
- margin: 0;
20
- padding: 0;
21
- background-color: #f4f4f4;
22
- }
23
- .container {
24
- width: 80%;
25
- margin: 50px auto;
26
- padding: 20px;
27
- background-color: white;
28
- box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
29
- border-radius: 10px;
30
- }
31
- #terminal {
32
- width: 100%;
33
- height: 300px;
34
- background-color: #000;
35
- color: #0f0;
36
- font-family: "Courier New", Courier, monospace;
37
- padding: 10px;
38
- overflow-y: scroll;
39
- white-space: pre-wrap;
40
- border-radius: 5px;
41
- margin-bottom: 10px;
42
- }
43
- #command {
44
- width: 100%;
45
- padding: 10px;
46
- font-size: 1em;
47
- margin-bottom: 10px;
48
- border: 1px solid #ccc;
49
- border-radius: 5px;
50
- }
51
- button {
52
- padding: 10px 20px;
53
- font-size: 1em;
54
- border: none;
55
- border-radius: 5px;
56
- background-color: #28a745;
57
- color: white;
58
- cursor: pointer;
59
- }
60
- button:hover {
61
- background-color: #218838;
62
- }
63
- </style>
64
- </head>
65
- <body>
66
- <div class="container">
67
- <h2>Administrator Terminal</h2>
68
- <div id="terminal"></div>
69
- <input type="text" id="command" placeholder="Enter command..." autofocus>
70
- <button onclick="sendCommand()">Execute</button>
71
- </div>
72
-
73
- <script>
74
- // Function to send command to the Flask server
75
- function sendCommand() {
76
- const command = document.getElementById('command').value;
77
-
78
- if (command.trim() === '') {
79
- alert('Please enter a command');
80
- return;
81
- }
82
-
83
- fetch('/execute', {
84
- method: 'POST',
85
- headers: {
86
- 'Content-Type': 'application/x-www-form-urlencoded'
87
- },
88
- body: new URLSearchParams({
89
- 'command': command
90
- })
91
- })
92
- .then(response => response.json())
93
- .then(data => {
94
- const terminal = document.getElementById('terminal');
95
- if (data.stdout) {
96
- terminal.innerHTML += '> ' + command + '\\n' + data.stdout + '\\n';
97
- }
98
- if (data.stderr) {
99
- terminal.innerHTML += '> ' + command + '\\n' + data.stderr + '\\n';
100
- }
101
- if (data.error) {
102
- terminal.innerHTML += '> ' + command + '\\n' + data.error + '\\n';
103
- }
104
- document.getElementById('command').value = '';
105
- terminal.scrollTop = terminal.scrollHeight; // Scroll to the bottom
106
- })
107
- .catch(error => {
108
- console.error('Error:', error);
109
- });
110
- }
111
-
112
- // Allow pressing Enter to send the command
113
- document.getElementById('command').addEventListener('keydown', function (e) {
114
- if (e.key === 'Enter') {
115
- sendCommand();
116
- }
117
- });
118
- </script>
119
- </body>
120
- </html>
121
- """
122
- return render_template_string(html)
123
-
124
- # Route to execute a command
125
- @app.route('/execute', methods=['POST'])
126
- def execute():
127
- try:
128
- # Get the command from the request
129
- command = request.form['command']
130
-
131
- # Execute the command and capture the output
132
- result = subprocess.run(command, shell=True, capture_output=True, text=True)
133
-
134
- # Return the output (stdout and stderr)
135
- return jsonify({
136
- 'stdout': result.stdout,
137
- 'stderr': result.stderr
138
- })
139
-
140
- except Exception as e:
141
- return jsonify({
142
- 'error': str(e)
143
- })
144
 
145
  if __name__ == '__main__':
146
- app.run(host='0.0.0.0', port=7860, debug=True)
 
1
+ from flask import Flask, request, jsonify, render_template
2
+ from PIL import Image
3
+ import base64
4
+ from io import BytesIO
5
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
6
+ import torch
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import cv2
10
 
11
  app = Flask(__name__)
12
 
13
+ processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
14
+ model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
15
+
16
+ def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
17
+ inputs = processor(
18
+ text=prompt, images=image, padding="max_length", return_tensors="pt"
19
+ )
20
+
21
+ # predict
22
+ with torch.no_grad():
23
+ outputs = model(**inputs)
24
+ preds = outputs.logits
25
+
26
+ pred = torch.sigmoid(preds)
27
+ mat = pred.cpu().numpy()
28
+ mask = Image.fromarray(np.uint8(mat * 255), "L")
29
+ mask = mask.convert("RGB")
30
+ mask = mask.resize(image.size)
31
+ mask = np.array(mask)[:, :, 0]
32
+
33
+ # normalize the mask
34
+ mask_min = mask.min()
35
+ mask_max = mask.max()
36
+ mask = (mask - mask_min) / (mask_max - mask_min)
37
+
38
+ # threshold the mask
39
+ bmask = mask > threshold
40
+ # zero out values below the threshold
41
+ mask[mask < threshold] = 0
42
+
43
+ fig, ax = plt.subplots()
44
+ ax.imshow(image)
45
+ ax.imshow(mask, alpha=alpha_value, cmap="jet")
46
+
47
+ if draw_rectangles:
48
+ contours, hierarchy = cv2.findContours(
49
+ bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
50
+ )
51
+ for contour in contours:
52
+ x, y, w, h = cv2.boundingRect(contour)
53
+ rect = plt.Rectangle(
54
+ (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
55
+ )
56
+ ax.add_patch(rect)
57
+
58
+ ax.axis("off")
59
+ plt.tight_layout()
60
+
61
+ bmask = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
62
+ output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
63
+ output_image.paste(image, mask=bmask)
64
+
65
+ # Convert mask to base64
66
+ buffered_mask = BytesIO()
67
+ bmask.save(buffered_mask, format="PNG")
68
+ result_mask = base64.b64encode(buffered_mask.getvalue()).decode('utf-8')
69
+
70
+ # Convert output image to base64
71
+ buffered_output = BytesIO()
72
+ output_image.save(buffered_output, format="PNG")
73
+ result_output = base64.b64encode(buffered_output.getvalue()).decode('utf-8')
74
+
75
+ return fig, result_mask, result_output
76
+
77
+ # Existing process_image function, copy it here
78
+ # ...
79
+
80
  @app.route('/')
81
  def index():
82
+ return render_template('index.html')
83
+
84
+ @app.route('/api/mask_image', methods=['POST'])
85
+ def mask_image_api():
86
+ data = request.get_json()
87
+
88
+ base64_image = data.get('base64_image', '')
89
+ prompt = data.get('prompt', '')
90
+ threshold = data.get('threshold', 0.4)
91
+ alpha_value = data.get('alpha_value', 0.5)
92
+ draw_rectangles = data.get('draw_rectangles', False)
93
+
94
+ # Decode base64 image
95
+ image_data = base64.b64decode(base64_image.split(',')[1])
96
+ image = Image.open(BytesIO(image_data))
97
+
98
+ # Process the image
99
+ _, result_mask, result_output = process_image(image, prompt, threshold, alpha_value, draw_rectangles)
100
+
101
+ return jsonify({'result_mask': result_mask, 'result_output': result_output})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  if __name__ == '__main__':
104
+ app.run(host='0.0.0.0', port=7860, debug=True)