cleave-fastapi / main.py
doublelotus's picture
ms
e0271c3
from flask import Flask, request, send_file, Response, jsonify
from flask_cors import CORS
import numpy as np
import io
import torch
import cv2
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
import requests
app = Flask(__name__)
# Allow cross-origin calls so a separately hosted frontend can hit these routes.
CORS(app)

# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
print(cudaOrNah)

# Global model setup: build the SAM ViT-L model from the local checkpoint once
# at import time; every request reuses the same `predictor`.
checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
# Move the model to the detected device. This was previously hard-coded to
# 'cuda', which made import fail on CPU-only hosts despite the check above.
sam.to(device=cudaOrNah)
predictor = SamPredictor(sam)
print('Setup SAM model')
@app.route('/')
def hello():
    """Root endpoint: return a fixed greeting payload (Flask renders the dict as JSON)."""
    payload = {"hei": "Malevolent Shrine :D"}
    return payload
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness endpoint: report that the process is up, always with HTTP 200."""
    status_code = 200
    body = jsonify({"status": "ok"})
    return body, status_code
@app.route('/get-npy')
def get_npy():
    """Compute the SAM image embedding for a remote image and return it as a .npy file.

    Query parameters:
        img_url: URL of the image to embed (required).

    Returns:
        On success, the embedding serialized with ``np.save`` and sent as an
        ``application/octet-stream`` attachment named ``embedding.npy``.
        ``400`` with ``{"error": "No img_url provided"}`` when ``img_url`` is
        missing or empty, and ``400`` with ``{"error": ..., "details": ...}``
        when fetching, decoding or embedding the image fails.
    """
    try:
        print('received image from frontend')
        # Default to empty string if the parameter is absent.
        img_url = request.args.get('img_url', '')
        if not img_url:
            return jsonify({"error": "No img_url provided"}), 400

        # NOTE(review): img_url comes straight from the client and is fetched
        # server-side, so callers can point this at internal hosts (SSRF).
        # Consider an allow-list if the service is publicly exposed.
        #
        # The timeout keeps an unresponsive host from tying up a worker
        # indefinitely. raise_for_status() turns 4xx/5xx responses into errors
        # instead of handing an HTML error page to PIL. Reading .content
        # (rather than .raw) lets requests undo any gzip/deflate
        # Content-Encoding. The `with` closes the connection.
        with requests.get(img_url, timeout=30) as response:
            response.raise_for_status()
            raw_image = Image.open(io.BytesIO(response.content)).convert("RGB")

        # SamPredictor.set_image defaults to image_format="RGB", so the RGB
        # array is passed as-is. The old RGB->BGR flip fed the model swapped
        # colour channels.
        image = np.array(raw_image)

        # NOTE(review): `predictor` is a single module-level object. If the
        # server handles requests concurrently, set_image/get_image_embedding
        # from two requests could interleave; consider serializing access.
        predictor.set_image(image)
        image_embedding = predictor.get_image_embedding().cpu().numpy()

        # Serialize the embedding in .npy format into an in-memory buffer.
        buffer = io.BytesIO()
        np.save(buffer, image_embedding)
        buffer.seek(0)
        return send_file(buffer, mimetype='application/octet-stream', as_attachment=True, download_name='embedding.npy')
    except Exception as e:
        # Top-level request boundary: report any failure to the client as a 400.
        print(f"Error processing the image: {e}")
        return jsonify({"error": "Error processing the image", "details": str(e)}), 400
if __name__ == '__main__':
    # Local development server only. Debug mode exposes Werkzeug's interactive
    # debugger, so never run it this way on a reachable interface.
    # The auto-reloader is disabled because it re-executes this module in a
    # child process. That would run the module-level SAM model loading a
    # second time and keep two copies of the model in memory.
    app.run(debug=True, use_reloader=False)