import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import warnings
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import json
import os
import torch
from scipy.ndimage import gaussian_filter
import cv2
from method import AdaCLIP_Trainer
import numpy as np

############ Init Model
ckt_path1 = 'weights/pretrained_mvtec_colondb.pth'
ckt_path2 = "weights/pretrained_visa_clinicdb.pth"
ckt_path3 = 'weights/pretrained_all.pth'

# Configurations
image_size = 518
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
model = "ViT-L-14-336"
prompting_depth = 4
prompting_length = 5
prompting_type = 'SD'
prompting_branch = 'VL'
use_hsf = True
k_clusters = 20

config_path = os.path.join('./model_configs', f'{model}.json')

# Prepare model
with open(config_path, 'r') as f:
    model_configs = json.load(f)

# Set up the feature hierarchy
n_layers = model_configs['vision_cfg']['layers']
substage = n_layers // 4
features_list = [substage, substage * 2, substage * 3, substage * 4]

model = AdaCLIP_Trainer(
    backbone=model,
    feat_list=features_list,
    input_dim=model_configs['vision_cfg']['width'],
    output_dim=model_configs['embed_dim'],
    learning_rate=0.,
    device=device,
    image_size=image_size,
    prompting_depth=prompting_depth,
    prompting_length=prompting_length,
    prompting_branch=prompting_branch,
    prompting_type=prompting_type,
    use_hsf=use_hsf,
    k_clusters=k_clusters
).to(device)


def process_image(image, text, options):
    # Load the model based on selected options
    if 'MVTec AD+Colondb' in options:
        model.load(ckt_path1)
    elif 'VisA+Clinicdb' in options:
        model.load(ckt_path2)
    elif 'All' in options:
        model.load(ckt_path3)
    else:
        # Default to 'All' if no valid option is provided
        model.load(ckt_path3)
        print('Invalid option. Defaulting to All.')

    # Ensure image is in RGB mode
    image = image.convert('RGB')

    # Convert PIL image to NumPy array
    np_image = np.array(image)

    # Convert RGB to BGR for OpenCV
    np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
    np_image = cv2.resize(np_image, (image_size, image_size))
    # Preprocess the image and run the model
    img_input = model.preprocess(image).unsqueeze(0)
    img_input = img_input.to(model.device)

    with torch.no_grad():
        anomaly_map, anomaly_score = model.clip_model(img_input, [text], aggregation=True)

    # Process anomaly map
    anomaly_map = anomaly_map[0, :, :].cpu().numpy()
    anomaly_score = anomaly_score[0].cpu().numpy()
    anomaly_map = gaussian_filter(anomaly_map, sigma=4)
    anomaly_map = (anomaly_map * 255).astype(np.uint8)

    # Apply color map and blend with original image
    heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
    vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)

    # Convert OpenCV image back to PIL image for Gradio
    vis_map_pil = Image.fromarray(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB))

    return vis_map_pil, f'{anomaly_score:.3f}'

# Define examples
examples = [
    ["asset/img.png", "candle", "MVTec AD+Colondb"],
    ["asset/img2.png", "bottle", "VisA+Clinicdb"],
    ["asset/img3.png", "button", "All"],
]

# Gradio interface layout
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Class Name"),
        gr.Radio(["MVTec AD+Colondb",
                  "VisA+Clinicdb",
                  "All"],
        label="Pre-trained Datasets")
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Anomaly Score"),
    ],
    examples=examples,
    title="AdaCLIP -- Zero-shot Anomaly Detection",
    description="Upload an image, enter class name, and select pre-trained datasets to do zero-shot anomaly detection"
)

# Launch the demo
demo.launch()
# demo.launch(server_name="0.0.0.0", server_port=10002)