Spaces: Running on T4
import io
import threading

import cv2
import numpy as np
import requests
import torch
from flask import Flask, request, send_file, Response, jsonify
from flask_cors import CORS
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
app = Flask(__name__)
CORS(app)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
print(cudaOrNah)

# Global model setup: a single SAM ViT-L model/predictor created at import
# time and used by the request handlers below.
checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
# Move the model to the device chosen above. Hard-coding "cuda" here made
# startup crash on GPU-less hosts even though a CPU fallback was computed.
sam.to(device=cudaOrNah)
predictor = SamPredictor(sam)
print('Setup SAM model')
def hello():
    """Return a fixed greeting payload (a dict Flask serializes as JSON)."""
    greeting = "Malevolent Shrine :D"
    return dict(hei=greeting)
def health_check():
    """Simple health check: a JSON ``{"status": "ok"}`` body with HTTP 200."""
    status_code = 200
    body = jsonify(status="ok")
    return body, status_code
# Serializes access to the shared module-level `predictor`: `set_image`
# stores per-image state on the predictor that `get_image_embedding` then
# reads, so two concurrent requests (Flask's server is threaded by default)
# could otherwise return each other's embeddings.
_predictor_lock = threading.Lock()


def _fetch_rgb_image(img_url, timeout=10):
    """Download the image at *img_url* and return it as an RGB PIL image.

    Raises ``requests.RequestException`` on network/HTTP errors (including
    non-2xx statuses) and ``OSError`` if PIL cannot decode the body.
    """
    # NOTE(review): img_url comes straight from the query string, so this is
    # a server-side fetch of an arbitrary caller-chosen URL (SSRF risk).
    # Consider restricting schemes/hosts if the service is publicly exposed.
    with requests.get(img_url, timeout=timeout) as resp:
        resp.raise_for_status()
        # Read .content rather than .raw so requests undoes any gzip/deflate
        # Content-Encoding before PIL sees the bytes.
        return Image.open(io.BytesIO(resp.content)).convert("RGB")


def _embedding_to_npy_buffer(embedding):
    """Serialize *embedding* with ``np.save`` into a rewound in-memory buffer."""
    buffer = io.BytesIO()
    np.save(buffer, embedding)
    buffer.seek(0)
    return buffer


def get_npy():
    """Compute the SAM image embedding for the image at ``img_url``.

    Query parameters:
        img_url: URL of the image to embed (required).

    Returns:
        On success, a ``send_file`` response carrying the embedding, serialized
        by ``np.save``, as an ``embedding.npy`` attachment. On any failure, a
        JSON ``{"error": ...}`` body (with ``details`` when an exception was
        raised) and HTTP 400.
    """
    print('received image from frontend')
    img_url = request.args.get('img_url', '')  # empty string if not provided
    if not img_url:
        return jsonify({"error": "No img_url provided"}), 400
    try:
        # SamPredictor.set_image expects RGB unless told otherwise. The old
        # RGB->BGR flip was passed as "RGB", so the encoder saw swapped
        # channels.
        image = np.array(_fetch_rgb_image(img_url))
        with _predictor_lock:
            predictor.set_image(image, image_format="RGB")
            image_embedding = predictor.get_image_embedding().cpu().numpy()
        buffer = _embedding_to_npy_buffer(image_embedding)
        return send_file(
            buffer,
            mimetype='application/octet-stream',
            as_attachment=True,
            download_name='embedding.npy',
        )
    except Exception as e:
        # Request boundary: report any failure to the client instead of a 500.
        print(f"Error processing the image: {e}")
        return jsonify({"error": "Error processing the image", "details": str(e)}), 400
if __name__ == '__main__':
    # Development server only. use_reloader=False because debug mode would
    # otherwise start Werkzeug's reloader, which re-executes this script in a
    # child process and loads the SAM checkpoint a second time. The
    # interactive debugger stays enabled via debug=True; never expose it
    # publicly, because it allows arbitrary code execution.
    app.run(debug=True, use_reloader=False)