import streamlit as st
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import numpy as np
import supervision as sv
import albumentations as A
import cv2
from transformers import AutoConfig
import yaml

# Page chrome: wide layout plus custom CSS overrides.
# set_page_config is kept ahead of every other st.* call (older Streamlit
# versions reject it otherwise).
st.set_page_config(layout="wide")

# Raw <style> block injected into the page; it must be rendered as HTML,
# hence unsafe_allow_html=True below.
_PAGE_CSS = """
    <style>
        .main {
            max-width: 1200px;  /* Max width for content */
            margin: 0 auto;
        }
        .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
            padding-left: 3rem;
            padding-right: 3rem;
        }
        .title {
            font-size: 2.5rem;
            text-align: center;
            color: #FF6347;
        }
        .subheader {
            font-size: 1.5rem;
            margin-bottom: 20px;
        }
        .btn {
            font-size: 1.1rem;
            padding: 10px 20px;
            background-color: #FF6347;
            color: white;
            border-radius: 5px;
            border: none;
            cursor: pointer;
        }
        .btn:hover {
            background-color: #FF4500;
        }
        .column-spacing {
            display: flex;
            justify-content: space-between;
        }
        .col-half {
            width: 48%;
        }
        .col-full {
            width: 100%;
        }
        .instructions {
            padding: 20px;
            background-color: #f9f9f9;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
        }
    </style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Load Model and Processor
@st.cache_resource
def load_model():
    """Load the grounding model, its processor and the device it runs on.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns instead of being reloaded on every
    interaction.

    Returns:
        tuple: ``(model, processor, device)``, where ``model`` has already
        been moved to ``device`` (CUDA when available, otherwise CPU).
    """
    # MODEL_NAME = "RioJune/AD-KD-MICCAI25"
    MODEL_NAME = 'Anonymous-AC/AD-KD'
    # Architecture config and processor both come from the Florence-2 base
    # checkpoint; only the weights come from MODEL_NAME.
    BASE_MODEL = "microsoft/Florence-2-base-ft"
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config_model = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True)
    # NOTE(review): presumably forces the vision backbone type expected by the
    # checkpoint's remote code — confirm before removing.
    config_model.vision_config.model_type = "davit"

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, trust_remote_code=True, config=config_model
    ).to(DEVICE)

    processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)
    # Override the image processor's resize and crop sizes to 512.
    processor.image_processor.size = 512
    processor.image_processor.crop_size = 512

    return model, processor, DEVICE


model, processor, DEVICE = load_model()

# Load Definitions
@st.cache_resource
def load_definitions():
    """Load the disease-definition and example-prompt YAML files.

    Returns:
        tuple: ``(vindr_definitions, padchest_definitions, prompt_definitions)``,
        each the mapping parsed from ``configs/vindr_definition.yaml``,
        ``configs/padchest_definition.yaml`` and ``examples/prompt.yaml``
        respectively. An empty file yields ``{}`` rather than ``None`` so
        callers can always use dict methods on the result.
    """

    def _read_yaml(path):
        # Explicit UTF-8: definitions may contain non-ASCII text, and the
        # default encoding otherwise depends on the host's locale.
        with open(path, 'r', encoding='utf-8') as file:
            return yaml.safe_load(file) or {}

    vindr_definitions = _read_yaml('configs/vindr_definition.yaml')
    padchest_definitions = _read_yaml('configs/padchest_definition.yaml')
    prompt_definitions = _read_yaml('examples/prompt.yaml')

    return vindr_definitions, padchest_definitions, prompt_definitions


vindr_definitions, padchest_definitions, prompt_definitions = load_definitions()

# Definition tables keyed by the dataset label offered in the dataset dropdown.
dataset_options = dict(Vindr=vindr_definitions, PadChest=padchest_definitions)


def load_example_images():
    """Return the example image paths (the keys of ``prompt_definitions``) as a list."""
    return [*prompt_definitions]


example_images = load_example_images()

def apply_transform(image, size_mode=512, output_size=512):
    """Letterbox *image* into a square and resize it for the model.

    The image is rescaled (INTER_AREA) so its longest side equals
    ``size_mode``, padded with a constant black border to
    ``size_mode`` x ``size_mode``, and finally resized to
    ``output_size`` x ``output_size``.

    Args:
        image: Image convertible with ``np.array`` (e.g. an RGB PIL image).
            The black fill value ``(0, 0, 0)`` assumes three channels.
        size_mode: Side length of the intermediate padded square.
        output_size: Side length of the returned image. Defaults to 512,
            the size this function always produced before it was configurable.

    Returns:
        numpy.ndarray: The transformed image array.
    """
    pad_resize_transform = A.Compose([
        A.LongestMaxSize(max_size=size_mode, interpolation=cv2.INTER_AREA),
        A.PadIfNeeded(min_height=size_mode, min_width=size_mode, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0)),
        A.Resize(height=output_size, width=output_size, interpolation=cv2.INTER_AREA),
    ])
    image_np = np.array(image)
    transformed = pad_resize_transform(image=image_np)
    return transformed["image"]

# Streamlit UI with Colorful Title and Emojis

# Florence-2 task token used both in the prompt and for post-processing.
_GROUNDING_TASK = "<CAPTION_TO_PHRASE_GROUNDING>"


def _render_guidance():
    """Render the "Instructions" and "Warning" panels shown beside each image."""
    st.subheader("⚙️ Instructions to Get Started:")
    st.write("""
        - **Run Inference**: Click the "Run Inference on Example" button to process the image and display the results.
        - **Choose an Example**: 🌄 Select an example image from the dataset to view its associated diseases.
        - **Upload Your Own Image**: 📤 Upload an image of your choice to analyze it for diseases.
        - **Select Dataset**: 📚 Choose between available datasets (Vindr or PadChest) for disease information.
        - **Select Disease**: 🦠 Pick the disease to be analyzed from the list of diseases in the selected dataset.
    """)

    st.subheader("⚠️ Warning:")
    st.write("""
    - **🚫 Please avoid uploading non-frontal chest X-ray images.** Our model has been specifically trained on **frontal chest X-ray images** only.
    - This demo is intended for **🔬 research purposes only** and should **❌ not be used for medical diagnoses**.
    - The model’s responses may contain **<span style='color:#dc3545; font-weight:bold;'>🤖 hallucinations or incorrect information</span>**.
    - Always consult a **<span style='color:#dc3545; font-weight:bold;'>👨‍⚕️ medical professional</span>** for accurate diagnosis and advice.
""", unsafe_allow_html=True)


def _lookup_definition(disease, preferred=None):
    """Return the knowledge description for *disease*, or ``""`` if none is known.

    *preferred* (e.g. the dataset the user selected) is searched first, then
    the VinDr definitions, then the PadChest definitions.
    """
    sources = ([preferred] if preferred is not None else []) + [vindr_definitions, padchest_definitions]
    for definitions in sources:
        text = definitions.get(disease)
        if text:
            return text
    return ""


def _run_grounding(image, disease, definition, original_caption):
    """Ground *disease* in *image* with the model and render the results.

    Writes the definition, the reconstructed beam-search probability, the
    original image beside a box-annotated copy, and the raw generated text.

    Args:
        image: RGB PIL image or ``HxWx3`` array to analyse.
        disease: Name of the abnormality to locate.
        definition: Knowledge description appended to the prompt.
        original_caption: Caption for the un-annotated image.
    """
    det_obj = f"{disease} means {definition}."
    st.write(f"**Definition:** {definition}")
    # Prompt kept verbatim (including the doubled final period); the checkpoint
    # may be sensitive to its exact wording, so re-validate before tidying it.
    prompt = f"{_GROUNDING_TASK}Locate the phrases in the caption: {det_obj}."

    np_image = np.array(image)
    # numpy shapes are (height, width, ...); the Florence-2 post-processor and
    # supervision both take sizes as (width, height).
    height, width = np_image.shape[:2]
    inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)

    with st.spinner("Processing... ⏳"):
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            output_scores=True,  # needed by compute_transition_scores below
            return_dict_in_generate=True,  # exposes sequences, scores and beam_indices
        )
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
        )
        # Special tokens are kept because post_process_generation parses the raw
        # decoded text (presumably including location tokens) to recover boxes.
        generated_text = processor.batch_decode(outputs.sequences, skip_special_tokens=False)[0]

        # Rebuild the beam-search sequence score: summed token log-probs divided
        # by (number of scored tokens ** length_penalty); tokens are counted as
        # the entries with a negative score. Exponentiate to get a probability.
        output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
        length_penalty = model.generation_config.length_penalty
        reconstructed_scores = transition_scores.cpu().sum(axis=1) / (output_length**length_penalty)
        probabilities = np.exp(reconstructed_scores.cpu().numpy())

        predictions = processor.post_process_generation(
            generated_text, task=_GROUNDING_TASK, image_size=(width, height)
        )
        detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=(width, height))

        bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
        label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
        annotated = bounding_box_annotator.annotate(np_image.copy(), detection)
        annotated = label_annotator.annotate(annotated, detection)
        annotated_image = Image.fromarray(annotated.astype(np.uint8))

    st.markdown(
        f"**🎯 Probability of the Results:** <span style='color:#28a745; font-size:24px; font-weight:bold;'>{probabilities[0] * 100:.2f}%</span>",
        unsafe_allow_html=True,
    )

    # Original and annotated images side by side.
    result_col1, result_col2 = st.columns([1, 1])
    with result_col1:
        st.image(image, caption=original_caption, width=400)
    with result_col2:
        st.image(annotated_image, caption="Inference Results 🖼️", width=400)

    st.write("**Generated Text:**", generated_text)


# ---- Header ----
st.markdown("<h1 class='title'>🩺 Enhancing Abnormality Grounding for Vision Language Models with Knowledge Descriptions 🚀</h1>", unsafe_allow_html=True)
st.markdown(
    "<p style='text-align: center; font-size: 18px;'>Welcome to a simple demo of our work! 🎉 Choose an example or upload your own image to get started! 👇</p>",
    unsafe_allow_html=True,
)

# ---- Example images ----
st.subheader("🌄 Example Images")
selected_example = st.selectbox("Choose an example", example_images)
# Context manager releases the file handle once the RGB copy is made.
with Image.open(selected_example) as example_file:
    image = example_file.convert("RGB")
example_diseases = prompt_definitions.get(selected_example, [])
st.write("**Associated Diseases:**", ", ".join(example_diseases))

col1, col2 = st.columns([1, 2])
with col1:
    st.image(image, caption=f"Original Example Image: {selected_example}", width=400)
with col2:
    _render_guidance()

if st.button("Run Inference on Example", key="example"):
    # Ground the first disease associated with the selected example.
    example_disease = example_diseases[0] if example_diseases else ""
    _run_grounding(
        image,
        example_disease,
        _lookup_definition(example_disease),
        f"Original Image: {selected_example}",
    )

# ---- Upload your own image ----
st.subheader("📤 Upload Your Own Image")

col1, col2 = st.columns([1, 1])
with col1:
    dataset_choice = st.selectbox("Select Dataset 📚", options=list(dataset_options.keys()))
    disease_options = list(dataset_options[dataset_choice].keys())
with col2:
    disease_choice = st.selectbox("Select Disease 🦠", options=disease_options)

uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])

# Stay None/"" until a file is uploaded, so the Run button can refuse to run.
uploaded_image = None
definition = ""

col1, col2 = st.columns([1, 2])
with col1:
    if uploaded_file:
        uploaded_image = apply_transform(Image.open(uploaded_file).convert("RGB"))
        st.image(uploaded_image, caption="Uploaded Image", width=400)

        # Fall back to the example's first disease when the dataset offers none.
        if not disease_choice and example_diseases:
            disease_choice = example_diseases[0]

        # Prefer the selected dataset's definition; otherwise ask the user.
        definition = _lookup_definition(disease_choice, dataset_options[dataset_choice])
        if not definition:
            definition = st.text_input("Enter Definition Manually 📝", value="")
with col2:
    _render_guidance()

if st.button("Run Inference 🏃‍♂️"):
    if uploaded_image is None:
        st.error("❌ Please upload an image first.")
    elif not definition:
        st.error("❌ Please enter a definition for the selected disease.")
    else:
        _run_grounding(uploaded_image, disease_choice, definition, "Uploaded Image")