from diffusers import DiffusionPipeline
from typing import List, Optional, Tuple, Union
import torch
import gradio as gr
css="""
#input-panel{
align-items:center;
justify-content:center
    
}
"""
modelnames=[
    "ahmedfaiyaz/OkkhorDiffusion",
    "ahmedfaiyaz/OkkhorDiffusion-CMATERdb",
    "ahmedfaiyaz/OkkhorDiffusion-Ekush"
]
current_model="ahmedfaiyaz/OkkhorDiffusion"
#updated username
pipeline = DiffusionPipeline.from_pretrained(current_model,custom_pipeline="ahmedfaiyaz/OkkhorDiffusion",embedding=torch.float16)
character_mappings = {
    'অ': 0,
    'আ': 1,
    'ই': 2,
    'ঈ': 3,
    'উ': 4,
    'ঊ': 5,
    'ঋ': 6,
    'এ': 7,
    'ঐ': 8,
    'ও': 9,
    'ঔ': 10,
    'ক': 11,
    'খ': 12,
    'গ': 13,
    'ঘ': 14,
    'ঙ': 15,
    'চ': 16,
    'ছ': 17,
    'জ': 18,
    'ঝ': 19,
    'ঞ': 20,
    'ট': 21,
    'ঠ': 22,
    'ড': 23,
    'ঢ': 24,
    'ণ': 25,
    'ত': 26,
    'থ': 27,
    'দ': 28,
    'ধ': 29,
    'ন': 30,
    'প': 31,
    'ফ': 32,
    'ব': 33,
    'ভ': 34,
    'ম': 35,
    'য': 36,
    'র': 37,
    'ল': 38,
    'শ': 39,
    'ষ': 40,
    'স': 41,
    'হ': 42,
    'ড়': 43,
    'ঢ়': 44,
    'য়': 45,
    'ৎ': 46,
    'ং': 47,
    'ঃ': 48,
    'ঁ': 49,
    '০': 50,
    '১': 51,
    '২': 52,
    '৩': 53,
    '৪': 54,
    '৫': 55,
    '৬': 56,
    '৭': 57,
    '৮': 58,
    '৯': 59,
    'ক্ষ(ksa)': 60,
    'ব্দ(bda)': 61,
    'ঙ্গ': 62,
    'স্ক': 63,
    'স্ফ': 64,
    'স্থ': 65,
    'চ্ছ': 66,
    'ক্ত': 67,
    'স্ন': 68,
    'ষ্ণ': 69,
    'ম্প': 70,
    'হ্ম': 71,
    'প্ত': 72,
    'ম্ব': 73,
    'ন্ড': 74,
    'দ্ভ': 75,
    'ত্থ': 76,
    'ষ্ঠ': 77,
    'ল্প': 78,
    'ষ্প': 79,
    'ন্দ': 80,
    'ন্ধ': 81,
    'ম্ম': 82,
    'ন্ঠ': 83,
}
ekush_mappings = {'অ': 1,
 'আ': 2,
 'ই': 3,
 'ঈ': 4,
 'উ': 5,
 'ঊ': 6,
 'ঋ': 7,
 'এ': 8,
 'ঐ': 9,
 'ও': 10,
 'ঔ': 11,
 'ক': 12,
 'খ': 13,
 'গ': 14,
 'ঘ': 15,
 'ঙ': 16,
 'চ': 17,
 'ছ': 18,
 'জ': 19,
 'ঝ': 20,
 'ঞ': 21,
 'ট': 22,
 'ঠ': 23,
 'ড': 24,
 'ঢ': 25,
 'ণ': 26,
 'ত': 27,
 'থ': 28,
 'দ': 29,
 'ধ': 30,
 'ন': 31,
 'প': 32,
 'ফ': 33,
 'ব': 34,
 'ভ': 35,
 'ম': 36,
 'য': 37,
 'র': 38,
 'ল': 39,
 'শ': 40,
 'ষ': 41,
 'স': 42,
 'হ': 43,
 'ড়': 44,
 'ঢ়': 45,
 'য়': 46,
 'ৎ': 47,
 'ং': 48,
 'ঃ': 49,
 'ঁ': 50,
 'ব্দ': 51,
 'ঙ্গ': 52,
 'স্ক': 53,
 'স্ফ': 54,
 'চ্ছ': 55,
 'স্থ': 56,
 'ক্ত': 57,
 'স্ন': 58,
 'ষ্ণ': 59,
 'ম্প': 60,
 'প্ত': 61,
 'ম্ব': 62,
 'ত্থ': 63,
 'দ্ভ': 64,
 'ষ্ঠ': 65,
 'ল্প': 66,
 'ষ্প': 67,
 'ন্দ': 68,
 'ন্ধ': 69,
 'স্ম': 70,
 'ণ্ঠ': 71,
 'স্ত': 72,
 'ষ্ট': 73,
 'ন্ম': 74,
 'ত্ত': 75,
 'ঙ্খ': 76,
 'ত্ন': 77,
 'ন্ড': 78,
 'জ্ঞ': 79,
 'ড্ড': 80,
 'ক্ষ': 81,
 'দ্ব': 82,
 'চ্চ': 83,
 'ক্র': 84,
 'দ্দ': 85,
 'জ্জ': 86,
 'ক্ক': 87,
 'ন্ত': 88,
 'ক্ট': 89,
 'ঞ্চ': 90,
 'ট্ট': 91,
 'শ্চ': 92,
 'ক্স': 93,
 'জ্ব': 94,
 'ঞ্জ': 95,
 'দ্ধ': 96,
 'ন্ন': 97,
 'ঘ্ন': 98,
 'ক্ল': 99,
 'হ্ন': 100,
 '০': 101,
 '১': 102,
 '২': 103,
 '৩': 104,
 '৪': 105,
 '৫': 106,
 '৬': 107,
 '৭': 108,
 '৮': 109,
 '৯': 110}

cmaterdb_mappings={
'প্র': 0,
 'ঙ্গ': 1,
 'ক্ষ': 2,
 'ত্র': 3,
 'ন্দ': 4,
 'চ্ছ': 5,
 'ন্ত': 6,
 'ন্দ্র': 7,
 'স্ত': 8,
 'ন্তু': 9,
 'গ্র': 10,
 'স্থ': 11,
 'স্ট': 12,
 'ম্ব': 13,
 'স্ব': 14,
 'ত্ত': 15,
 'ক্ত': 16,
 'ন্ট': 17,
 'ল্প': 18,
 'ষ্ট': 19,
 'ন্ত্র': 20,
 'ক্র': 21,
 'ন্ন': 22,
 'দ্ধ': 23,
 'ন্ধ': 24,
 'ঙ্ক': 25,
 'ন্ড': 26,
 'ফ্র': 27,
 'ম্প': 28,
 'স্ক': 29,
 'জ্ঞ': 30,
 'ক্ট': 31,
 'শ্চ': 32,
 'ট্র': 33,
 'ত্ব': 34,
 'ল্ল': 35,
 'ব্র': 36,
 'ঞ্চ': 37,
 'ণ্ড': 38,
 'ক্স': 39,
 'শ্র': 40,
 'দ্র': 41,
 'স্প': 42,
 'ঞ্জ': 43,
 'ন্স': 44,
 'ম্ভ': 45,
 'শ্ব': 46,
 'ব্দ': 47,
 'শ্ন': 48,
 'প্প': 49,
 'ব্ল': 50,
 'প্ত': 51,
 'ক্ল': 52,
 'ষ্ট্র': 53,
 'দ্ব': 54,
 'ট্ট': 55,
 'গ্ল': 56,
 'ল্ট': 57,
 'ষ্ঠ': 58,
 'স্ত্র': 59,
 'প্ল': 60,
 'চ্চ': 61,
 'স্ম': 62,
 'দ্দ': 63,
 'গ্ন': 64,
 'জ্ব': 65,
 'ষ্ক': 66,
 'ত্ম': 67,
 'ড্র': 68,
 'ম্ম': 69,
 'ণ্ট': 70,
 'ম্প্র': 71,
 'প্ন': 72,
 'ন্ম': 73,
 'স্ফ': 74,
 'ল্দ': 75,
 'ত্ত্ব': 76,
 'জ্জ': 77,
 'ক্ষ্ম': 78,
 'ষ্ণ': 79,
 'ন্ব': 80,
 'ক্ক': 81,
 'ন্থ': 82,
 'ড্ড': 83,
 'ব্ব': 84,
 'ন্ট্র': 85,
 'ণ্ঠ': 86,
 'প্ট': 87,
 'স্তু': 88,
 'ধ্ব': 89,
 'হ্ণ': 90,
 'ভ্র': 91,
 'ল্ক': 92,
 'স্ল': 93,
 'হ্ন': 94,
 'ত্ন': 95,
 'ষ্ক্র': 96,
 'ঘ্র': 97,
 'দ্ভ': 98,
 'শ্ল': 99,
 'ব্ধ': 100,
 'ষ্ম': 101,
 'স্ক্র': 102,
 'ড়্গ': 103,
 'জ্জ্ব': 104,
 'শ্ম': 105,
 'দ্ম': 106,
 'ক্ব': 107,
 'ম্র': 108,
 'গ্ধ': 109,
 'ব্জ': 110,
 'স্ন': 111,
 'ন্দ্ব': 112,
 'হ্ম': 113,
 'ঙ্ঘ': 114,
 'খ্র': 115,
 'ত্থ': 116,
 'ল্ব': 117,
 'ম্ন': 118,
 'ঘ্ন': 119,
 'গ্গ': 120,
 'ক্ষ্ণ': 121,
 'গ্রু': 122,
 'চ্ছ্ব': 123,
 'ণ্ণ': 124,
 'ল্ম': 125,
 'স্র': 126,
 'ম্ল': 127,
 'ষ্প্র': 128,
 'ঞ্ঝ': 129,
 'স্প্র': 130,
 'ম্ভ্র': 131,
 'ষ্প': 132,
 'ঙ্খ': 133,
 'জ্র': 134,
 'গ্ব': 135,
 'থ্ব': 136,
 'ণ্ব': 137,
 'হ্ব': 138,
 'দ্দ্ব': 139,
 'দ্ঘ': 140,
 'ধ্র': 141,
 'হ্ল': 142,
 'গ্ম': 143,
 'ল্গ': 144,
 'স্খ': 145,
 'থ্র': 146,
 'ন্ধ্র': 147,
 'ফ্ল': 148,
 'ঙ্ক্ষ': 149,
 'ণ্ম': 150,
 'ঞ্ছ': 151,
 'ম্ফ': 152,
 'হ্র': 153,
 'প্রু': 154,
 'ত্রু': 155,
 'ভ্ল': 156,
 'শ্রু': 157,
 'দ্রু': 158,
 'ঙ্ম': 159,
 'ক্ম': 160,
 'দ্গ': 161,
 'ন্ড্র': 162,
 'ট্ব': 163,
 'চ্ঞ': 164,
 'প্স': 165,
 'ল্ড': 166,
 'ষ্ফ': 167,
 'শ্ছ': 168,
 'জ্ঝ': 169,
 'স্ট্র': 170,
 'অ': 171,
 'আ': 172,
 'ই': 173,
 'ঈ': 174,
 'উ': 175,
 'ঊ': 176,
 'ঋ': 177,
 'এ': 178,
 'ঐ': 179,
 'ও': 180,
 'ঔ': 181,
 'ক': 182,
 'খ': 183,
 'গ': 184,
 'ঘ': 185,
 'ঙ': 186,
 'চ': 187,
 'ছ': 188,
 'জ': 189,
 'ঝ': 190,
 'ঞ': 191,
 'ট': 192,
 'ঠ': 193,
 'ড': 194,
 'ঢ': 195,
 'ণ': 196,
 'ত': 197,
 'থ': 198,
 'দ': 199,
 'ধ': 200,
 'ন': 201,
 'প': 202,
 'ফ': 203,
 'ব': 204,
 'ভ': 205,
 'ম': 206,
 'য': 207,
 'র': 208,
 'ল': 209,
 'শ': 210,
 'ষ': 211,
 'স': 212,
 'হ': 213,
 'ড়': 214,
 'ঢ়': 215,
 'য়': 216,
 'ৎ': 217,
 'ং': 218,
 'ঃ': 219,
 'ঁ': 220}
character_mappings_model_wise={
    "ahmedfaiyaz/OkkhorDiffusion":character_mappings,
    "ahmedfaiyaz/OkkhorDiffusion-CMATERdb":cmaterdb_mappings,
    "ahmedfaiyaz/OkkhorDiffusion-Ekush":ekush_mappings
}


def generate(modelname:str,input_text:str,batch_size:int,inference_steps:int):
    batch_size=int(batch_size)
    inference_steps=int(inference_steps)
    print(f"Generating image with label:{character_mappings_model_wise[current_model][input_text]} batch size:{batch_size}")
    label=int(character_mappings_model_wise[current_model][input_text])
    pipeline.embedding=torch.tensor([label],device="cpu") #testing zero gpu
    generate_image=pipeline(batch_size=batch_size,num_inference_steps=inference_steps).images
    return generate_image


def switch_pipeline(modelname:str):
    global pipeline
    pipeline = DiffusionPipeline.from_pretrained(modelname,custom_pipeline="ahmedfaiyaz/OkkhorDiffusion",embedding=torch.int16)
    
    
    global current_model
    current_model=modelname
    return f"""
            <div style="text-align: center; margin: 0 auto;">Selected model <a href="https://huggingface.co/{modelname}">{modelname}</a></div>
            """,gr.update(choices=character_mappings_model_wise[modelname])


with gr.Blocks(css=css,elem_id="panel") as od_app:
    with gr.Column(min_width=100):
        text=gr.HTML("""
         <div style="text-align: center; margin: 0 auto;">
              <div style="display: inline-flex;align-items: center;gap: 0.8rem;font-size: 1.75rem;">
                <h1> Okkhor Diffusion </h1>
               </div>
        </div>
""")
    #input panel 
    choosen_model=gr.HTML(f"""
                          <div style="text-align: center; margin: 0 auto;">Selected model <a href="https://huggingface.co/{current_model}">{current_model}</a></div>
                          """)
    with gr.Row(elem_id="input-panel"):

        with gr.Column(variant="panel",scale=0,elem_id="input-panel-items"):
            choose_model=gr.Dropdown(label="Select model",choices=modelnames,value=modelnames[0])
            dropdown = gr.Dropdown(label="Select Character",choices=list(character_mappings_model_wise[current_model].keys()))
            batch_size = gr.Number(label="Batch Size", minimum=0, maximum=100)
            inference_steps= gr.Slider(label="Steps",value=100,minimum=100,maximum=1000,step=100)

            btn = gr.Button("Generate",size="sm")  
      
        
    gallery = gr.Gallery(
        label="Generated images", show_label=False, elem_id="gallery"
    , columns=[10], rows=[10], object_fit="contain", height="auto",scale=1,min_width=80)

    choose_model.change(fn=switch_pipeline,inputs=[choose_model],outputs=[choosen_model,dropdown]) 
    btn.click(fn=generate,inputs=[choose_model,dropdown,batch_size,inference_steps],outputs=[gallery])
    
if __name__=='__main__':
    od_app.queue(max_size=20).launch(show_error=True)