import torch
import gradio as gr
import numpy as np
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')
from PIL import Image
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       save_as_images, display_in_terminal)
# App defaults: the architecture and ImageNet class used out of the box.
# The generator for the default architecture is loaded once, at import time.
initial_archi, initial_class = 'biggan-deep-128', 'dog'

gan_model = BigGAN.from_pretrained(initial_archi)

# Name of the architecture whose generator was loaded at import time.
_DEFAULT_GAN_ARCHI = initial_archi

# Generators loaded so far, keyed by architecture name. Seeded with the
# import-time model so the default architecture is never loaded twice.
# NOTE(review): unbounded; fine for the handful of architectures the UI offers,
# but every generator that has been requested stays resident in memory.
_GAN_MODELS = {_DEFAULT_GAN_ARCHI: gan_model}


def _get_gan_model(archi):
    """Return the pretrained BigGAN for ``archi``, loading and caching it on first use."""
    model = _GAN_MODELS.get(archi)
    if model is None:
        model = BigGAN.from_pretrained(archi)
        _GAN_MODELS[archi] = model
    return model


def generate_images(initial_archi, initial_class, batch_size, truncation=0.4):
    """ Generate a batch of BigGAN samples for one class.
        Params:
            initial_archi: name of the pretrained architecture to sample from
                (e.g. 'biggan-deep-256'). A falsy value (e.g. no dropdown
                selection) falls back to the model loaded at import time.
            initial_class: class name resolved with ``one_hot_from_names``
                (e.g. 'dog'). Surrounding whitespace is ignored.
            batch_size: number of samples to draw.
            truncation: truncation value passed to both the noise sampler and
                the generator; the default 0.4 matches the previous behavior.
        Output:
            CPU tensor produced by the generator for the whole batch.
        Raises:
            ValueError: if ``initial_class`` does not resolve to a class.
    """
    model = _get_gan_model(initial_archi or _DEFAULT_GAN_ARCHI)

    # Strip stray whitespace from free-text input so e.g. 'dog ' still resolves.
    class_name = initial_class.strip() if isinstance(initial_class, str) else initial_class
    class_vector = one_hot_from_names(class_name, batch_size=batch_size)
    if class_vector is None:
        # one_hot_from_names returns None when nothing matches; fail with a
        # readable message instead of a TypeError from torch.from_numpy(None).
        raise ValueError(f"Could not find an ImageNet class matching {initial_class!r}")
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=batch_size)

    # Put the inputs wherever the generator's weights live (CPU unless the
    # model has been moved to a GPU).
    device = next(model.parameters()).device
    noise_vector = torch.from_numpy(noise_vector).to(device)
    class_vector = torch.from_numpy(class_vector).to(device)

    with torch.no_grad():
        output = model(noise_vector, class_vector, truncation)

    # Return a CPU tensor so it can be converted to numpy directly.
    output = output.to('cpu')
    # NOTE(review): save_as_images persists the batch as image files on every
    # call; concurrent requests may overwrite each other's files.
    save_as_images(output)
    return output
    
def convert_to_images(obj):
    """ Convert a batch of BigGAN outputs into a list of Pillow images.
        Params:
            obj: torch tensor (on any device, with or without autograd
                history) or numpy array of shape
                (batch_size, channels, height, width). Values are assumed to
                lie roughly in [-1, 1]; anything that maps outside the
                [0, 255] pixel range is clipped.
        Output:
            list of batch_size Pillow Images, each height x width pixels
            (Pillow reports this as ``size == (width, height)``).
        Raises:
            ImportError: if Pillow is not installed.
    """
    try:
        # Import the submodule explicitly: a bare ``import PIL`` does not load
        # PIL.Image by itself.
        from PIL import Image as PILImage
    except ImportError as err:
        raise ImportError("Please install Pillow to use images: pip install Pillow") from err

    if not isinstance(obj, np.ndarray):
        # detach() drops autograd history; cpu() is required because numpy()
        # rejects tensors that live on an accelerator.
        obj = obj.detach().cpu().numpy()

    # (batch, C, H, W) -> (batch, H, W, C): channel-last layout for Pillow.
    obj = obj.transpose((0, 2, 3, 1))
    # Map [-1, 1] onto [0, 256] and clip, so the very top values saturate at 255.
    obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255)
    pixels = obj.astype(np.uint8)

    return [PILImage.fromarray(frame) for frame in pixels]
    
def inference(initial_archi, initial_class):
    """Gradio callback: draw one BigGAN sample and return it as a Pillow image.

    Params:
        initial_archi: architecture name forwarded to generate_images.
        initial_class: class name forwarded to generate_images.
    Output:
        the single image produced for a batch of size 1.
    """
    batch = generate_images(initial_archi, initial_class, batch_size=1)
    return convert_to_images(batch)[0]
  


title = "BigGAN"
description = "BigGAN using various architecture models to generate images."
article = "Coming soon"

# Pretrained architectures selectable in the UI. This single list also drives
# the examples, so the two cannot drift apart.
architectures = ["biggan-deep-128", "biggan-deep-256", "biggan-deep-512"]

# One example row per architecture, all for the same class.
examples = [[archi, "dog"] for archi in architectures]

# Build and launch the web UI. inference receives (architecture, class name)
# and returns a single PIL image.
gr.Interface(
    inference,
    inputs=[
        # Preselect an architecture so the value is set even if the user never
        # touches the dropdown.
        gr.inputs.Dropdown(architectures, default=architectures[0]),
        "text",
    ],
    outputs=[gr.outputs.Image(type="pil", label="output")],
    examples=examples,
    title=title,
    description=description,
    article=article,
).launch(debug=True)