from transformers import SamModel, SamConfig, SamProcessor
import torch
import numpy as np
import matplotlib.pyplot as plt
import app
import os
import json
from PIL import Image

def pred(src):
    """Run a SAM model loaded from a local checkpoint on the image at `src`
    and return the predicted probability map and the thresholded binary mask."""
    # -- cache
    cache_dir = "/code/cache"

    # -- build the model from the base SAM config and load the local fine-tuned weights
    MODEL_FILE = "sam_model.pth"
    model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)

    model = SamModel(config=model_config)
    model.load_state_dict(torch.load(MODEL_FILE, map_location=torch.device("cpu")))

    # -- build the processor, overriding defaults with the modified processor config
    with open("sam-config.json", "r") as f:  # modified config json file
        modified_config_dict = json.load(f)

    processor = SamProcessor.from_pretrained(
        "facebook/sam-vit-base",
        **modified_config_dict,
        cache_dir=cache_dir,
    )
    
    # -- load the image and convert it to an RGB array
    image = Image.open(src)
    rgbim = image.convert("RGB")
    new_image = np.array(rgbim)
    print()
    print("image shape:", new_image.shape)

    # preprocess for SAM; no point/box prompts are provided
    inputs = processor(new_image, return_tensors="pt")
    model.eval()

    # forward pass
    print("predicting...")
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"], 
                        multimask_output=False)
    
    # apply sigmoid
    print("apply sigmoid...")
    pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
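    # Note: pred_masks come out at SAM's low mask resolution (256x256); if a mask
    # at the original image size is needed, it can be resized afterwards, e.g. with
    # processor.image_processor.post_process_masks (not done here).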

    # convert the soft probability map to a hard binary mask
    PROBABILITY_THRES = 0.30
    pred_prob = pred_prob.cpu().numpy().squeeze()
    pred_mask = (pred_prob > PROBABILITY_THRES).astype(np.uint8)

    return pred_prob, pred_mask
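

# Minimal usage sketch (assumptions: "example.png" is a placeholder image path,
# and "sam_model.pth" / "sam-config.json" exist next to this script).
if __name__ == "__main__":
    prob_map, binary_mask = pred("example.png")  # placeholder image path

    # show the probability map next to the thresholded mask
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(prob_map, cmap="viridis")
    axes[0].set_title("predicted probability")
    axes[1].imshow(binary_mask, cmap="gray")
    axes[1].set_title("binary mask (> 0.30)")
    for ax in axes:
        ax.axis("off")
    plt.tight_layout()
    plt.show()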