from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.units import percent, px
from htbuilder.funcs import rgba, rgb
import streamlit as st
import os
import sys
import argparse
import clip
import numpy as np
from PIL import Image
#from dalle.models import Dalle
#from dalle.utils.utils import set_seed, clip_score
import streamlit.components.v1 as components
import torch 
from IPython.display import display
import random

def link(link, text, **style):
    """Build an ``<a>`` element that opens *link* in a new browser tab.

    Args:
        link: Target URL, used as the ``href`` attribute.
        text: Anchor content.
        **style: CSS properties forwarded to htbuilder's ``styles``.
    """
    css = styles(**style)
    anchor = a(_href=link, _target="_blank", style=css)
    return anchor(text)

def layout(*args):
    """Inject page CSS and render a fixed footer built from *args*.

    The injected CSS hides the element with id ``MainMenu`` and any
    ``<footer>`` element, and sets ``bottom: 125px`` on ``.stApp``
    (presumably to keep app content clear of this footer -- confirm).

    Args:
        *args: Footer content. Each ``str`` or ``HtmlElement`` is appended,
            in order, to the paragraph under the footer's horizontal rule;
            values of any other type are ignored.
    """
    # "#MainMenu" must not have a space after "#": "# MainMenu" is not a valid
    # CSS ID selector, so the browser discards that whole rule.
    style = """
    <style>
      #MainMenu {visibility: hidden;}
      footer {visibility: hidden;}
     .stApp { bottom: 125px; }
    </style>
    """

    # Full-width container pinned to the bottom edge of the viewport.
    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    # `body` is attached to `foot` now and filled in below; htbuilder elements
    # are mutated in place by calling them with children.
    body = p()
    foot = div(style=style_div)(
        hr(style=style_hr),
        body,
    )

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)


def footer():
    """Render the credits footer and inject the Google Analytics tag.

    The credit line is rendered through ``layout``. The gtag.js snippet is
    passed unchanged to ``components.html``.
    """
    credits = (
        [
            "This app uses the ",
            link("https://github.com/kuprel/min-dalle", "min(DALL·E)"),
            " port of ",
            link("https://github.com/borisdayma/dalle-mini", "DALL·E mini"),
            br(),
        ]
        + [
            "Created by ",
            link("https://jonathanmalott.com", "Jonathan Malott"),
            br(),
        ]
        + [
            link("https://bridgingbarriers.utexas.edu/good-systems", "Good Systems Grand Challenge"),
            ", The University of Texas at Austin.",
            " Advised by Dr. Junfeng Jiao.",
            br(),
            br(),
        ]
    )
    layout(*credits)

    # Google Analytics (gtag.js) snippet, injected verbatim.
    analytics_html = """
    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=G-SB6NJ9DQS7"></script>
    <script>
      window.dataLayer = window.dataLayer || [];
      function gtag(){dataLayer.push(arguments);}
      gtag('js', new Date());

      gtag('config', 'G-SB6NJ9DQS7');
    </script>
    """
    components.html(analytics_html)


from min_dalle import MinDalle

# Shared MinDalle instance, created lazily by _get_min_dalle().
_min_dalle_model = None


def _get_min_dalle():
    """Return the cached MinDalle model, constructing it on first use.

    NOTE(review): the cache lives as long as this module object. If this file
    is itself the Streamlit entry script, it is re-executed on each rerun and
    the cache is rebuilt -- confirm how the module is loaded. The instance is
    also shared by every session that calls generate2; confirm concurrent
    generate_image calls are acceptable.
    """
    global _min_dalle_model
    if _min_dalle_model is None:
        _min_dalle_model = MinDalle(
            models_root='./pretrained',
            dtype=torch.float32,
            device='cpu',
            is_mega=False,
            is_reusable=True,
        )
    return _min_dalle_model


def generate2(prompt, crazy, k):
    """Generate one image with min(DALL·E) and append it to the session results.

    Args:
        prompt: User text prompt. " architecture" is appended for generation
            unless "architecture" already appears (case-insensitive). The
            unmodified prompt is what gets stored.
        crazy: Sampling temperature passed to ``generate_image``.
        k: ``top_k`` passed to ``generate_image``.

    Side effects:
        Appends a dict with keys 'prompt', 'crazy', 'k', 'image' to
        ``st.session_state.results``, which must already exist.
    """
    # Reuse one model (is_reusable=True) instead of constructing a new
    # MinDalle, and loading it, on every call.
    mm = _get_min_dalle()

    # Sampling
    newPrompt = prompt
    if "architecture" not in prompt.lower():
        newPrompt += " architecture"

    image = mm.generate_image(
        text=newPrompt,
        seed=np.random.randint(0, 10000),
        grid_size=1,
        is_seamless=False,
        temperature=crazy,
        top_k=k,
        supercondition_factor=32,
        is_verbose=False,
    )

    st.session_state.results.append({
        'prompt': prompt,
        'crazy': crazy,
        'k': k,
        'image': image,
    })

# Lazily loaded minDALL-E model shared across calls. ``False`` means
# "not loaded yet" (kept as ``False`` for compatibility with any reader).
model = False
# Lazily loaded ``(model_clip, preprocess_clip)`` pair from ``clip.load``;
# ``None`` until the first call to ``generate``.
_clip_model_and_preprocess = None


def generate(prompt, crazy, k):
    """Generate one image with minDALL-E and append it to the session results.

    NOTE(review): ``Dalle``, ``set_seed`` and ``clip_score`` come from imports
    that are commented out at the top of this file, so as written this
    function raises ``NameError``. ``generate2`` is the min-dalle path.

    Args:
        prompt: User text prompt. " architecture" is appended for sampling
            unless "architecture" already appears (case-insensitive). The
            unmodified prompt is what gets stored.
        crazy: Softmax temperature used for sampling.
        k: Top-k cutoff used for sampling; recorded with the result so the
            grid label matches what was actually used.

    Side effects:
        Appends a dict with keys 'prompt', 'crazy', 'k', 'image' to
        ``st.session_state.results``, which must already exist.
    """
    global model, _clip_model_and_preprocess

    device = 'cpu'
    if model is False:
        # This will automatically download the pretrained model.
        model = Dalle.from_pretrained('minDALL-E/1.3B')
        model.to(device=device)

    num_candidates = 1

    set_seed(np.random.randint(0, 10000))

    # Sampling
    newPrompt = prompt
    if "architecture" not in prompt.lower():
        newPrompt += " architecture"

    # Previously top_k was hard-coded to 256 while k was ignored and 20 was
    # recorded; use and record the caller's k, as generate2 does.
    images = model.sampling(prompt=newPrompt,
                            top_k=k,
                            top_p=None,
                            softmax_temperature=crazy,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    # Move axis 1 to the end (presumably channels-first -> channels-last
    # for PIL -- confirm against model.sampling's output layout).
    images = np.transpose(images, (0, 2, 3, 1))

    # CLIP re-ranking. Load CLIP once instead of on every call.
    if _clip_model_and_preprocess is None:
        model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
        model_clip.to(device=device)
        _clip_model_and_preprocess = (model_clip, preprocess_clip)
    model_clip, preprocess_clip = _clip_model_and_preprocess
    rank = clip_score(prompt=newPrompt,
                      images=images,
                      model_clip=model_clip,
                      preprocess_clip=preprocess_clip,
                      device=device)

    # Assumes clip_score returns candidate indices ordered best-first (a 0-d
    # array when there is a single candidate) -- confirm. Taking the first
    # index keeps `result` a single image even if num_candidates > 1.
    result = images[np.atleast_1d(rank)[0]]

    st.session_state.results.append({
        'prompt': prompt,
        'crazy': crazy,
        'k': k,
        'image': Image.fromarray((result * 255).astype(np.uint8)),
    })


def drawGrid():
    """Render stored results grouped by prompt, temperature and top-k.

    Reads ``st.session_state.results`` newest first. Results are grouped by
    the key ``"<prompt> <crazy> <k>"``, and groups are ordered by their newest
    member. Each group is shown under a subheader in a three-column grid.

    Side effects:
        Replaces ``st.session_state.images`` with the ``st.image`` handles
        created by this draw. It used to be appended to on every draw, so it
        grew without bound across reruns.
    """
    groups = {}
    for r in reversed(st.session_state.results):
        key = r['prompt'] + " " + str(r['crazy']) + " " + str(r['k'])
        groups.setdefault(key, []).append(r)

    # NOTE(review): adds one empty placeholder per image drawn last time; the
    # intent is unclear (Streamlit rebuilds the page on each rerun anyway).
    for _ in st.session_state.images:
        st.empty()

    drawn = []
    placeholder = st.empty()
    with placeholder.container():
        for group in groups.values():
            first = group[0]
            txt = first['prompt'] + " (temperature:" + str(first['crazy']) + ", top k:" + str(first['k']) + ")"
            st.subheader(txt)
            columns = st.columns(3)

            # Fill the grid row by row: item ix goes into column ix % 3.
            for ix, item in enumerate(group):
                with columns[ix % 3]:
                    drawn.append(st.image(item["image"]))

    st.session_state.images = drawn