import gradio as gr
import requests
import os
from PIL import Image
from io import BytesIO
from tqdm import tqdm
import time

repo = "artificialguybr/TshirtDesignRedmond-V2"

# Retry/timeout budget for the Inference API loop in `infer`.
# HTTP 503 is retried (the loop treats it as "model still loading"); without a
# cap a request that never succeeds would block the Gradio handler forever.
MAX_LOADING_RETRIES = 300  # with RETRY_DELAY_SECONDS = 1, roughly 5 minutes
MAX_SERVER_ERROR_RETRIES = 5  # retries allowed for HTTP 500 responses
RETRY_DELAY_SECONDS = 1
REQUEST_TIMEOUT_SECONDS = 120  # per-request timeout passed to requests.post


def _build_prompt(color_prompt, dress_type_prompt, design_prompt):
    """Return the text-to-image prompt describing one garment mock-up."""
    # Every fragment ends with a space (or is the last one) so that the
    # implicit string concatenation yields properly separated words/sentences.
    return (
        f"A high-quality digital image of a {color_prompt} {dress_type_prompt}, "
        f"featuring a {design_prompt} printed in sharp detail on the {dress_type_prompt}, "
        "facing front, hanging on a plain wall. "
        "The fabric has realistic texture, "
        "smooth folds, and accurate lighting. The design is perfectly aligned, with natural shadows "
        "and highlights, creating a photorealistic look."
    )


def infer(color_prompt, dress_type_prompt, design_prompt):
    """Generate a garment image for the given description via the HF Inference API.

    Args:
        color_prompt: Garment colour, e.g. "Red".
        dress_type_prompt: Garment type, e.g. "T-shirt".
        design_prompt: Description of the printed design.

    Returns:
        A PIL image decoded from the body of the successful (HTTP 200) response.

    Raises:
        RuntimeError: on any non-retried status code, or once the 503/500
            retry budget is exhausted.
        requests.RequestException: on network failures or request timeout.
    """
    prompt = _build_prompt(color_prompt, dress_type_prompt, design_prompt)

    print("Generating image with prompt:", prompt)
    api_url = f"https://api-inference.huggingface.co/models/{repo}"

    # Optional authentication. NOTE(review): the env var name HF_TOKEN is an
    # assumption -- adjust to whatever the deployment provides. When unset,
    # the request is sent without credentials, as before.
    token = os.environ.get("HF_TOKEN")
    headers = {"Authorization": f"Bearer {token}"} if token else {}

    payload = {
        "inputs": prompt,
        "parameters": {
            # Steer the model away from common artefacts and non-photo styles.
            "negative_prompt": "low quality, artifacts, distorted, blurry, overexposed, underexposed, unrealistic texture, poor lighting, misaligned print, plastic-like fabric, grainy, washed-out colors, 3D render, cartoon, digital art, watermark, bad anatomy, malformed, cluttered design",
            "num_inference_steps": 30,
            # NOTE(review): unverified whether the Inference API honours this.
            "scheduler": "EulerAncestralDiscreteScheduler",
        },
    }

    loading_retries = 0
    error_count = 0
    # Context manager guarantees the progress bar is closed on every exit path.
    with tqdm(total=None, desc="Loading model") as pbar:
        while True:
            print("Sending request to API...")
            response = requests.post(
                api_url,
                headers=headers,
                json=payload,
                timeout=REQUEST_TIMEOUT_SECONDS,
            )
            print("API response status code:", response.status_code)

            if response.status_code == 200:
                print("Image generation successful!")
                return Image.open(BytesIO(response.content))

            if response.status_code == 503 and loading_retries < MAX_LOADING_RETRIES:
                loading_retries += 1
                pbar.update(1)
            elif response.status_code == 500 and error_count < MAX_SERVER_ERROR_RETRIES:
                error_count += 1
            else:
                print("API Error:", response.status_code)
                raise RuntimeError(f"API Error: {response.status_code}")

            time.sleep(RETRY_DELAY_SECONDS)

# Gradio Interface
# (line count, placeholder) for each text input, in the order `infer` expects
# its positional arguments: colour, garment type, design.
_TEXTBOX_SPECS = (
    (1, "Color Prompt"),
    (1, "Dress Type Prompt"),
    (2, "Design Prompt"),
)

iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(lines=n_lines, placeholder=hint)
        for n_lines, hint in _TEXTBOX_SPECS
    ],
    outputs="image",
    title="Make your Brand",
    description="Generation of clothes",
    examples=[["Red", "T-shirt", "Simple design"]],
)

# Start the web server (runs at import time, as this is the app entry script).
print("Launching Gradio interface...")
iface.launch()