import gradio as gr
import requests
from io import BytesIO
from PIL import Image
import os
import tempfile

# Configuration read from the environment.
API_URL = os.getenv("API_URL")  # Endpoint that generate_image posts requests to
model_id = os.getenv("MODEL_ID")  # Sent as the "id" field of each request

# Token rotation state. The active API token is read from the environment
# variable TOKEN<token_id>.
no_of_accounts = 6  # Number of accounts available for rotation on rate limiting
token_id = 1  # Number of the account whose token is currently in use
TOKEN = os.getenv(f"TOKEN{token_id}")
tokens_tried = 0  # Rate-limit retry bookkeeping; starts at zero

def get_image_from_url(url, timeout=60):
    """
    Fetch the image at ``url`` and decode it with Pillow.

    No format conversion is performed; the image is returned in whatever
    format the server sent.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait on the server before giving up, so a stalled
            server cannot hang the caller indefinitely.

    Returns:
        ``(image, url)`` on success, where ``image`` is a ``PIL.Image.Image``.
        On failure, ``(error_message, None)`` where ``error_message`` is a str.
    """
    try:
        # The whole body is read via ``.content`` anyway, so no streaming.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; force decoding now so corrupt or truncated data
        # is reported by the handler below instead of surfacing later.
        image.load()
    except requests.exceptions.RequestException as e:
        return f"Error fetching image: {e}", None
    except Exception as e:
        return f"Error processing image: {e}", None
    return image, url  # Return the image and the URL


def generate_image(prompt, aspect_ratio, realism):
    """
    Generate an image for ``prompt`` through the remote API at ``API_URL``.

    Posts ``model_id`` and the three UI inputs to the API. If the API reports
    rate limiting (an error message containing ``"error 429"``), the
    module-level ``token_id``/``TOKEN`` are advanced to the next account
    (environment variables ``TOKEN1`` .. ``TOKEN<no_of_accounts>``) and the
    request is retried. Each configured account is tried at most once per call,
    and the last selected account stays active for later calls.

    Args:
        prompt: Text description of the desired image.
        aspect_ratio: Aspect-ratio string such as ``"16:9"``.
        realism: Flag sent to the API as the lowercase string of its value
            (``"true"``/``"false"`` for a bool).

    Returns:
        ``(image, url)`` on success (as produced by ``get_image_from_url``);
        otherwise ``(error_message, None)`` where ``error_message`` is a str.
    """
    global token_id, TOKEN

    payload = {
        "id": model_id,
        "inputs": [prompt, aspect_ratio, str(realism).lower()],
    }
    # Always make at least one attempt, and never take a modulo by zero below.
    accounts = max(no_of_accounts, 1)

    for _ in range(accounts):
        headers = {"Authorization": f"Bearer {TOKEN}"}
        try:
            response = requests.post(
                API_URL,
                json=payload,
                headers=headers,
                # Generous because generation can be slow, but bounded so a
                # stalled server cannot hang this handler forever.
                timeout=300,
            )
            response_data = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a response body that is not valid JSON.
            return f"Error: {e}", None

        if not isinstance(response_data, dict):
            return "Error: Unexpected response from server", None

        if "error" in response_data:
            error = response_data["error"]
            if "error 429" not in str(error):
                return f"Error: {error}", None
            # Rate limited: switch to the next account, cycling 1..accounts
            # (the old `% (no_of_accounts + 1)` wrongly produced TOKEN0).
            token_id = token_id % accounts + 1
            TOKEN = os.getenv(f"TOKEN{token_id}")
            continue

        if "output" in response_data:
            return get_image_from_url(response_data["output"])

        return "Error: Unexpected response from server", None

    # Every account was rate limited.
    return "No credits available", None

def download_image(image_url, timeout=60):
    """
    Download the image at ``image_url`` into a temporary ``.png`` file.

    Best-effort: any failure yields ``None`` instead of raising. A partially
    written temporary file is removed so it does not leak on disk.

    Args:
        image_url: URL of the image. A falsy value (``None`` or ``""``) means
            there is nothing to download.
        timeout: Seconds to wait on the server before giving up.

    Returns:
        Path of the temporary file on success, otherwise ``None``. The file
        is created with ``delete=False``, so the caller owns it and is
        responsible for removing it.
    """
    if not image_url:
        return None  # Nothing to download

    temp_file_path = None
    try:
        # ``with`` releases the streamed connection even if writing fails.
        with requests.get(image_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # delete=False: the file must remain usable after this returns.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
                temp_file_path = tmp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)
        return temp_file_path
    except Exception:
        # Best-effort: drop any partial download and signal failure with None.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
        return None
    
# Define the Gradio interface. The three inputs map positionally onto
# generate_image(prompt, aspect_ratio, realism), and its (image, url) return
# value maps onto the two outputs.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate"),
        gr.Radio(
            choices=["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
            label="Aspect Ratio",
            value="16:9",  # Default value
        ),
        # Boolean checkbox; generate_image sends it to the API as "true"/"false".
        gr.Checkbox(label="Realism", value=False),
    ],
    outputs=[
        gr.Image(type="pil"),  # Generated image
        # Receives the image URL returned by generate_image for downloading.
        gr.File(label="Download Image", file_types=[".png", ".jpg", ".jpeg"]),
    ],
    title="Image Generator",
    description="Provide a prompt, select an aspect ratio, and set realism to generate an image.",
)

# Launch only when run as a script, so importing this module does not start
# a web server as a side effect.
if __name__ == "__main__":
    interface.launch()