import base64
import io
from fastapi import FastAPI, UploadFile, File, HTTPException
import os
from PIL import Image
from fastapi.responses import JSONResponse
from semantic_seg_model import segmentation_inference
from similarity_inference import similarity_inference
from gradio_client import Client, file
from datetime import datetime
from fastapi.middleware.cors import CORSMiddleware


app = FastAPI(docs_url="/")

# Cross-origin policy: any origin, the usual REST verbs, any header.
# NOTE(review): a "*" origin together with allow_credentials=True is very
# permissive (the CORS spec forbids a literal "*" on credentialed responses);
# consider narrowing allowed_origins before a public deployment.
allowed_origins = ["*"]
_cors_settings = {
    "allow_origins": allowed_origins,
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "PUT", "DELETE"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Working directories: /image-render writes its temporary uploads under
# input_images_dir; /image-analyzer uses the "processed/" subfolder as its
# scratch space for segmentation output.
input_images_dir = "image/"
temp_processing_dir = os.path.join(input_images_dir, "processed/")

# Define a function to handle the POST request at `image-analyzer`
@app.post("/image-analyzer")
async def image_analyzer(image: UploadFile = File(...)):
    """
    Find materials similar to the wall, ceiling and floor of an uploaded image.

    The upload may be raw image bytes or a base64 data URL
    (``data:image/<type>;base64,...``). The decoded image is passed to
    ``segmentation_inference`` together with ``temp_processing_dir``
    (presumably its output directory), and ``similarity_inference`` is then
    run over that same directory.

    Args:
        image: The uploaded image file.

    Returns:
        Whatever ``similarity_inference`` returns -- per the original
        description, the PolyHaven URLs of the top-k materials most similar
        to the wall, ceiling and floor.

    Raises:
        HTTPException: status 500 for any failure during processing.
    """
    try:
        # Read the upload BEFORE touching the shared scratch directory. This
        # is the handler's only `await`; everything below runs without
        # yielding, so another request on this event loop cannot interleave
        # between clearing the directory and reading the results back from it.
        # NOTE(review): separate worker processes would still share the
        # directory -- use a per-request temp dir if the app is scaled out.
        contents = await image.read()

        # Start from an empty scratch directory, creating it if missing.
        os.makedirs(temp_processing_dir, exist_ok=True)
        for filename in os.listdir(temp_processing_dir):
            file_path = os.path.join(temp_processing_dir, filename)
            try:
                os.remove(file_path)  # Remove the file
            except OSError as e:
                print(f"Failed to delete {file_path}. Reason: {e}")

        # Accept base64 data URLs for any image type, not only PNG.
        if contents.startswith(b"data:image/") and b";base64," in contents:
            contents = base64.b64decode(contents.partition(b";base64,")[2])
        img = Image.open(io.BytesIO(contents))

        print("image loaded successfully. Processing image for segmentation and similarity inference...", datetime.now())

        # segment into components
        segmentation_inference(img, temp_processing_dir)
        print("image segmented successfully. Starting similarity inference...", datetime.now())

        # identify similar materials for each component
        matching_textures = similarity_inference(temp_processing_dir)
        print("done", datetime.now())

        return matching_textures

    except Exception as e:
        print(str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
    

# Gradio client for the hosted "StableDesign" Space, used by /image-render.
# It is built at module import time. NOTE(review): presumably this contacts
# the Space on construction, so a network failure here would stop the app
# from starting -- confirm.
STABLE_DESIGN_SPACE = "MykolaL/StableDesign"
client = Client(STABLE_DESIGN_SPACE)

@app.post("/image-render")
async def image_render(prompt: str, image: UploadFile = File(...)):
    """
    Generate a new image from an uploaded image and a prompt via "StableDesign".

    The upload may be raw image bytes or a base64 data URL
    (``data:image/<type>;base64,...``). It is converted to grayscale, saved
    as a temporary PNG under ``input_images_dir``, and sent together with
    ``prompt`` to the Gradio Space through ``client.predict``. Both the
    temporary input and the returned file are deleted afterwards.

    Args:
        prompt: Text describing the desired design.
        image: The uploaded image file.

    Returns:
        A 200 ``JSONResponse`` of the form
        ``{"image": "<base64 of the generated image file>"}``.

    Raises:
        HTTPException: 404 if the path returned by the model does not exist;
            500 for any other failure.
    """
    image_path = None
    new_image_path = None
    try:
        print(f"received prompt: {prompt} and image: {image}")
        # `filename` is client-supplied: keep only its last path component so
        # it cannot steer the write outside input_images_dir, and fall back
        # to a fixed stem when the client sent no name (it may be None).
        stem = os.path.basename(image.filename or "") or "upload"
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        image_path = os.path.join(input_images_dir, stem + timestamp + ".png")

        contents = await image.read()
        # Accept base64 data URLs for any image type, not only PNG.
        if contents.startswith(b"data:image/") and b";base64," in contents:
            contents = base64.b64decode(contents.partition(b";base64,")[2])
        img = Image.open(io.BytesIO(contents))

        # The model is given a grayscale copy of the input.
        img.convert("L").save(image_path)

        # NOTE(review): client.predict is synchronous inside an async handler,
        # so it blocks the event loop for the whole generation.
        new_image_path = client.predict(
                image=file(image_path),
                text=prompt,
                num_steps=50,
                guidance_scale=10,
                seed=1111664444,
                strength=1,
                a_prompt="interior design, 4K, high resolution, photorealistic",
                n_prompt="window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner",
                img_size=768,
                api_name="/on_submit"
        )

        if not os.path.exists(new_image_path):
            raise HTTPException(status_code=404, detail="Image not found")

        # Open the image file and convert it to base64
        with open(new_image_path, "rb") as img_file:
            base64_str = base64.b64encode(img_file.read()).decode("utf-8")

        return JSONResponse(content={"image": base64_str}, status_code=200)
    except HTTPException:
        # Propagate deliberate HTTP errors (the 404 above) unchanged instead
        # of re-wrapping them as a generic 500.
        raise
    except Exception as e:
        print(str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Delete the temporary input and the model's output on every path,
        # including when predict() or the file read fails.
        for path in (image_path, new_image_path):
            if not isinstance(path, (str, os.PathLike)):
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
            except OSError as err:
                print(f"Failed to delete {path}. Reason: {err}")