from flask import Flask, request, send_file, Response, jsonify
from flask_cors import CORS
import numpy as np
import io
import torch
import cv2
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
import requests

# Module-level Flask application; the route decorators below register on it.
app = Flask(__name__)
# Wrap the app with flask_cors. No options are passed, so the library's
# defaults apply.
# NOTE(review): confirm the default CORS policy is acceptable for deployment.
CORS(app)

# Select the compute device once at import time: CUDA when torch reports it
# as available, otherwise fall back to the CPU.
cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
print(cudaOrNah)

# Global model setup: the SAM model and its predictor are built once at
# import time and shared by the request handlers below.
checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
# Move the model to the device detected above. Hard-coding 'cuda' here would
# make the module fail to import on hosts without a usable GPU.
sam.to(device=cudaOrNah)
predictor = SamPredictor(sam)
print('Setup SAM model')

@app.route('/')
def hello():
    """Root endpoint: reply with a fixed greeting payload."""
    greeting = {"hei": "Malevolent Shrine :D"}
    return greeting

@app.route('/health', methods=['GET'])
def health_check():
    """Health probe: respond with a constant JSON body and HTTP 200."""
    status_body = {"status": "ok"}
    return jsonify(status_body), 200

@app.route('/get-npy')
def get_npy():
    """Compute the SAM image embedding for a remote image and return it as .npy.

    Query parameters:
        img_url: URL of the image to download and embed. Required.

    Returns:
        On success, an ``application/octet-stream`` attachment named
        ``embedding.npy`` holding the embedding array as written by
        ``np.save``. If ``img_url`` is missing or any step fails, a JSON
        error body with HTTP status 400.
    """
    try:
        print('received image from frontend')
        # Get the 'img_url' from the query parameters
        img_url = request.args.get('img_url', '')  # Default to empty string if not provided

        if not img_url:
            return jsonify({"error": "No img_url provided"}), 400

        # NOTE(review): img_url is caller-supplied and fetched from this
        # server, so it can be pointed at internal addresses. Consider
        # restricting the allowed schemes/hosts before exposing this publicly.
        # A timeout keeps a stalled remote host from hanging the worker, and
        # raise_for_status() stops an HTTP error page from being fed to the
        # image decoder.
        response = requests.get(img_url, timeout=30)
        response.raise_for_status()

        # Decode from the fully downloaded body. response.content has any
        # transfer/content encoding already removed, unlike response.raw.
        raw_image = Image.open(io.BytesIO(response.content)).convert("RGB")
        # Convert the PIL Image to an RGB NumPy array
        image_array = np.array(raw_image)

        # SamPredictor.set_image expects the channel order named by
        # image_format. The array is RGB, so pass it through unchanged rather
        # than reversing the channels to BGR.
        # NOTE(review): `predictor` is shared module state; set_image followed
        # by get_image_embedding is presumably not safe under concurrent
        # requests. Confirm the serving model or add a lock.
        predictor.set_image(image_array, image_format="RGB")
        image_embedding = predictor.get_image_embedding().cpu().numpy()

        # Serialize the embedding array into an in-memory .npy file
        buffer = io.BytesIO()
        np.save(buffer, image_embedding)
        buffer.seek(0)

        # Create a response with the correct MIME type
        return send_file(buffer, mimetype='application/octet-stream', as_attachment=True, download_name='embedding.npy')
    except Exception as e:
        # Top-level request boundary: report any failure to the client
        # instead of letting the exception propagate.
        print(f"Error processing the image: {e}")
        # Return a JSON response with the error message and a 400 Bad Request status
        return jsonify({"error": "Error processing the image", "details": str(e)}), 400

# Start the Flask server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    # NOTE(review): debug=True is meant for local development; confirm this
    # entry point is not used to serve production traffic.
    app.run(debug=True)