doublelotus commited on
Commit
f546076
·
1 Parent(s): c0ca689
Files changed (2) hide show
  1. main.py +43 -26
  2. requirements.txt +8 -8
main.py CHANGED
@@ -1,9 +1,25 @@
1
  from flask import Flask, request, send_file, Response, jsonify
2
  from flask_cors import CORS
 
 
 
 
 
 
 
3
 
4
  app = Flask(__name__)
5
  CORS(app)
6
 
 
 
 
 
 
 
 
 
 
7
  @app.route('/')
8
  def hello():
9
  return {"hei": "you succesfully deployed"}
@@ -15,37 +31,38 @@ def health_check():
15
 
16
  @app.route('/get-npy')
17
  def get_npy():
18
- # # Get the 'img_url' from the query parameters
19
- # img_url = request.args.get('img_url', '') # Default to empty string if not provided
 
20
 
21
- # if not img_url:
22
- # return jsonify({"error": "No img_url provided"}), 400
23
 
24
- # raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
25
- # # Convert the PIL Image to a NumPy array
26
- # image_array = np.array(raw_image)
27
- # # Since OpenCV expects BGR, convert RGB to BGR
28
- # image = image_array[:, :, ::-1]
29
-
30
- # if image is None:
31
- # raise ValueError("Image not found or unable to read.")
32
 
33
- # predictor.set_image(image)
34
- # image_embedding = predictor.get_image_embedding().cpu().numpy()
35
 
36
- # # Convert the embedding array to bytes
37
- # buffer = io.BytesIO()
38
- # np.save(buffer, image_embedding)
39
- # buffer.seek(0)
40
 
41
- # # Create a response with the correct MIME type
42
- # return send_file(buffer, mimetype='application/octet-stream', as_attachment=True, download_name='embedding.npy')
43
- # except Exception as e:
44
- # # Log the error message if needed
45
- # print(f"Error processing the image: {e}")
46
- # # Return a JSON response with the error message and a 400 Bad Request status
47
- # return jsonify({"error": "Error processing the image", "details": str(e)}), 400
48
- return {"hei": "gotnpy"}
49
 
50
  if __name__ == '__main__':
51
  app.run(debug=True)
 
1
  from flask import Flask, request, send_file, Response, jsonify
2
  from flask_cors import CORS
3
+ import numpy as np
4
+ import io
5
+ import torch
6
+ import cv2
7
+ from segment_anything import sam_model_registry, SamPredictor
8
+ from PIL import Image
9
+ import requests
10
 
11
  app = Flask(__name__)
12
  CORS(app)
13
 
14
+ print('cuda available:' + torch.cuda.is_available())
15
+
16
+ # Global model setup
17
+ checkpoint = "sam_vit_l_0b3195.pth"
18
+ model_type = "vit_l"
19
+ sam = sam_model_registry[model_type](checkpoint=checkpoint)
20
+ sam.to(device='cuda')
21
+ predictor = SamPredictor(sam)
22
+
23
  @app.route('/')
24
  def hello():
25
  return {"hei": "you succesfully deployed"}
 
31
 
32
  @app.route('/get-npy')
33
  def get_npy():
34
+ try:
35
+ # Get the 'img_url' from the query parameters
36
+ img_url = request.args.get('img_url', '') # Default to empty string if not provided
37
 
38
+ if not img_url:
39
+ return jsonify({"error": "No img_url provided"}), 400
40
 
41
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
42
+ # Convert the PIL Image to a NumPy array
43
+ image_array = np.array(raw_image)
44
+ # Since OpenCV expects BGR, convert RGB to BGR
45
+ image = image_array[:, :, ::-1]
46
+
47
+ if image is None:
48
+ raise ValueError("Image not found or unable to read.")
49
 
50
+ predictor.set_image(image)
51
+ image_embedding = predictor.get_image_embedding().cpu().numpy()
52
 
53
+ # Convert the embedding array to bytes
54
+ buffer = io.BytesIO()
55
+ np.save(buffer, image_embedding)
56
+ buffer.seek(0)
57
 
58
+ # Create a response with the correct MIME type
59
+ return send_file(buffer, mimetype='application/octet-stream', as_attachment=True, download_name='embedding.npy')
60
+ except Exception as e:
61
+ # Log the error message if needed
62
+ print(f"Error processing the image: {e}")
63
+ # Return a JSON response with the error message and a 400 Bad Request status
64
+ return jsonify({"error": "Error processing the image", "details": str(e)}), 400
65
+ # return {"hei": "gotnpy"}
66
 
67
  if __name__ == '__main__':
68
  app.run(debug=True)
requirements.txt CHANGED
@@ -1,11 +1,11 @@
1
  flask
2
  gunicorn
3
  flask-cors
4
- # numpy
5
- # opencv-python
6
- # Pillow
7
- # requests
8
- # git+https://github.com/facebookresearch/segment-anything.git
9
- # --extra-index-url https://download.pytorch.org/whl/cu113
10
- # torch
11
- # torchvision
 
1
  flask
2
  gunicorn
3
  flask-cors
4
+ numpy
5
+ opencv-python
6
+ Pillow
7
+ requests
8
+ git+https://github.com/facebookresearch/segment-anything.git
9
+ --extra-index-url https://download.pytorch.org/whl/cu113
10
+ torch
11
+ torchvision