import torch
from pathlib import Path
from transformers import CLIPProcessor, CLIPModel
from PIL import Image, ImageDraw
import pytesseract
import requests
import os 
from llm import inference, upload_image
from fastapi.responses import FileResponse, JSONResponse

import re

from io import BytesIO

# Directory (relative to the working directory) where per-detection crops are written.
cropped_images_dir = "cropped_images"
# Create it at import time; an already-existing directory is fine.
Path(cropped_images_dir).mkdir(parents=True, exist_ok=True)

# YOLOv5 detector wrapper: crops each detection, runs OCR on it, and asks an LLM for brand details.
class YOLOModel:
    """Wrap a torch.hub YOLOv5 model plus OCR, LLM and (optional) CLIP helpers."""

    # CLIP is only needed by predict_clip(); it is loaded lazily on first use.
    clip_model = None
    clip_processor = None
    clip_model_name = "openai/clip-vit-base-patch32"

    def __init__(self, model_path="yolov5s.pt", *, force_reload=True,
                 clip_model_name="openai/clip-vit-base-patch32"):
        """
        Initialize the YOLO model via torch.hub from the ultralytics/yolov5 repo.

        :param model_path: Weights file passed to the repo's "custom" entry point.
        :param force_reload: Passed to ``torch.hub.load``; when True the hub repo
            is fetched again instead of reusing the local cache.
        :param clip_model_name: Hugging Face checkpoint that predict_clip() loads
            on first use.
        """
        # Bypass torch.hub's forked-repo check, presumably to avoid GitHub API
        # failures (e.g. rate limiting). NOTE: this patches a private torch API.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load(
            "ultralytics/yolov5", "custom", path=model_path, force_reload=force_reload
        )
        self.clip_model_name = clip_model_name

    def _ensure_clip_loaded(self):
        """Load the CLIP processor and model the first time they are needed."""
        if self.clip_model is None or self.clip_processor is None:
            self.clip_processor = CLIPProcessor.from_pretrained(self.clip_model_name)
            self.clip_model = CLIPModel.from_pretrained(self.clip_model_name)
            self.clip_model.eval()

    def predict_clip(self, image, brand_names):
        """
        Predict the most probable brand for ``image`` using CLIP.

        :param image: Image handed to the CLIP processor (e.g. a PIL crop).
        :param brand_names: Non-empty sequence of candidate brand strings.
        :return: ``(best_brand, probability)`` for the first image in the batch.
        :raises ValueError: If ``brand_names`` is empty.
        """
        if not brand_names:
            raise ValueError("brand_names must contain at least one candidate")
        self._ensure_clip_loaded()
        inputs = self.clip_processor(
            text=brand_names,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        # Softmax over the text candidates (dim=1), then score the first image row.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        best_idx = int(probs.argmax())
        return brand_names[best_idx], probs[best_idx].item()

    def predict_text(self, image):
        """
        OCR ``image`` with pytesseract and return the stripped text.

        Best-effort: any failure is printed and an empty string is returned,
        so OCR problems never abort the detection pipeline.
        """
        try:
            grayscale = image.convert('L')
            text = pytesseract.image_to_string(grayscale)
            return text.strip()
        except Exception as e:
            print(f"Error during text prediction: {e}")
            return ""

    @staticmethod
    def _remove_stale_crops(directory, valid_indices):
        """Delete ``crop_<int>.jpg`` files in ``directory`` whose index is not in ``valid_indices``."""
        for filename in os.listdir(directory):
            if not (filename.startswith("crop_") and filename.endswith(".jpg")):
                continue
            try:
                file_idx = int(filename.split("_")[1].split(".")[0])
            except ValueError:
                # Name looks like a crop but has no integer index; leave it alone.
                continue
            if file_idx not in valid_indices:
                os.remove(os.path.join(directory, filename))
                print(f"Deleted excess file: {filename}")

    def predict(self, image_path):
        """
        Run YOLO inference on an image and enrich every detection.

        For each detected box this crops the region and saves it as
        ``crop_<idx>.jpg`` in ``cropped_images_dir``. It then uploads the crop
        with ``upload_image.upload_image_to_imgbb`` to obtain a URL, queries
        ``inference.get_name`` with that URL and the YOLO class name, and OCRs
        the crop. Finally, crop files whose index was not produced by this
        call are deleted.

        NOTE(review): crops use index-based names in one shared directory, so
        concurrent calls can overwrite or delete each other's files. Confirm
        that calls are serialized.

        :param image_path: Path to the input image
        :return: One dict per detection with keys ``category``, ``bbox``
            ([xmin, ymin, xmax, ymax]), ``confidence``, ``category_llm``
            (LLM "brand"), ``predicted_brand`` (LLM "model"), ``price``,
            ``details`` (LLM "description"), ``detected_text`` (OCR),
            ``image_path`` and ``image_url``.
        """
        results = self.model(image_path)
        # Close the source file promptly; convert() returns an independent copy.
        with Image.open(image_path) as src:
            image = src.convert("RGB")
        # Draw on a separate copy so boxes and labels for one detection never
        # end up inside the crops of later, overlapping detections.
        # NOTE(review): the annotated image is currently neither saved nor returned.
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)
        predictions = results.pandas().xyxy[0]  # Get predictions as pandas DataFrame
        print(f'YOLO predictions:\n\n{predictions}')

        output = []
        written_indices = set()

        for idx, row in predictions.iterrows():
            category = row['name']
            confidence = row['confidence']
            bbox = [row["xmin"], row["ymin"], row["xmax"], row["ymax"]]

            # Crop from the clean (unannotated) image.
            cropped_image = image.crop(tuple(bbox))
            cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg")
            cropped_image.save(cropped_image_path, "JPEG")
            written_indices.add(idx)

            # Upload to obtain a public URL the LLM can fetch.
            print('Uploading now to image url')
            image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
            print(f'Image URL received as{image_url}')
            # Reads the keys "brand", "model", "price" and "description" below.
            result_llms = inference.get_name(image_url, category)

            detected_text = self.predict_text(cropped_image)
            print(f'Details:{detected_text}')
            print(f'Predicted brand: {result_llms["model"]}')

            draw.rectangle(bbox, outline="red", width=3)
            draw.text(
                (bbox[0], bbox[1] - 10),
                f'{result_llms["brand"]}',
                fill="red"
            )

            output.append({
                "category": category,
                "bbox": bbox,
                "confidence": confidence,
                "category_llm": result_llms["brand"],
                "predicted_brand": result_llms["model"],
                "price": result_llms["price"],
                "details": result_llms["description"],
                "detected_text": detected_text,
                "image_path": cropped_image_path,
                "image_url": image_url,
            })

        # One cleanup pass after all crops are written; this also clears
        # leftovers from earlier calls when nothing was detected this time.
        self._remove_stale_crops(cropped_images_dir, written_indices)

        return output