File size: 3,696 Bytes
64ea77f
232aeb3
 
038d520
 
d90716a
232aeb3
4af4d04
22c64f8
 
232aeb3
 
24d77bb
d90716a
 
 
 
 
fa3d600
d90716a
 
 
fbecb15
28c23ee
 
fbecb15
406927c
3df1457
232aeb3
28c23ee
4403c6c
d90716a
 
3df1457
4af4d04
d90716a
 
24d77bb
 
bf8eed8
0aa3d75
d90716a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aa3d75
d90716a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aa3d75
 
 
 
 
 
 
 
 
eb5ce46
0aa3d75
eb5ce46
0aa3d75
ef21c7e
 
0aa3d75
4af4d04
0aa3d75
4af4d04
 
a71b15c
 
 
 
 
 
4af4d04
a71b15c
 
4af4d04
dfa0ea3
a71b15c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import streamlit as st
import cv2 as cv
import time
import torch
from diffusers import StableDiffusionPipeline
from transformers import GPT2Tokenizer, GPT2LMHeadModel


def create_model(loc="stabilityai/stable-diffusion-2-1-base", mch='cpu'):
    """Load a Stable Diffusion pipeline and move it onto a device.

    Args:
        loc: Model id or path handed to ``StableDiffusionPipeline.from_pretrained``.
        mch: Device passed to the pipeline's ``.to()`` call (``'cpu'`` by default).

    Returns:
        The pipeline object returned by ``.to(mch)``.
    """
    return StableDiffusionPipeline.from_pretrained(loc).to(mch)


def tok_mod():
    """Load the GPT-2 prompt-generator model and its tokenizer on the CPU.

    The tokenizer is ``distilgpt2`` with an extra ``[PAD]`` special token
    registered; the weights are ``FredZhang7/distilgpt2-stable-diffusion-v2``.

    Returns:
        A ``(model, tokenizer)`` tuple -- the model comes first.
    """
    gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
    gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    gpt_model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')
    gpt_model.to('cpu')
    return gpt_model, gpt_tokenizer


# Page header with usage instructions. The button named here must match the
# label of the model-creation button further down ("Create The Model").
t2i = st.title("""
Txt2Img
###### `CLICK "Create The Model"` :
- `FIRST RUN OF THE CODE`
- `CHANGING MODEL`
###### TO USE GPT PROMPTS GENERATOR CHECK `GPT PROMS` THEN CLICK `CREATE GPT MODEL`""")

# Diffusion checkpoint that "Create The Model" will load.
the_type = st.selectbox("Model", ("stabilityai/stable-diffusion-2-1-base",
                                  "CompVis/stable-diffusion-v1-4"))
# NOTE(review): reset on every rerun and not read anywhere else in this file.
st.session_state.gate = False

# Three equal-width columns; only the two outer ones receive widgets.
ma_1, _, ma_2 = st.columns([2, 2, 2])

# Load the selected diffusion pipeline into session state on demand so it
# persists across Streamlit reruns.
create = ma_1.button("Create The Model")
if create:
    st.session_state["t2m_mod"] = create_model(loc=the_type)

# Optional GPT-2 prompt generator: expands a short seed text into several
# candidate Stable Diffusion prompts stored in st.session_state.PROMPTS.
with ma_2:
    gpt = st.checkbox("GPT PROMS")

if gpt:
    gen = st.button("Create GPT Model")
    if gen:
        st.session_state.mod, st.session_state.tok = tok_mod()

    m1, m2, m3 = st.columns([1, 1, 3])
    m4, m5 = st.columns(2)
    prompt = st.text_input("GPT PROM", r'')
    with m1:
        # Sampling (do_sample=True) needs a strictly positive temperature;
        # transformers rejects 0.0, so the slider starts at 0.1.
        temperature = st.slider("Temp", 0.1, 1.0, .9, .1)
    with m2:
        top_k = st.slider("K", 2, 16, 8, 2)
    with m3:
        max_length = st.slider("Length", 10, 100, 80, 1)
    with m4:
        # Step 0.1 keeps the 1.2 default on the slider grid.
        repitition_penalty = st.slider("penality", 1.0, 5.0, 1.2, 0.1)
    with m5:
        num_return_sequences = st.slider("Proms Num", 1, 10, 5, 1)

    prom_gen = st.button("Generate Proms")

    if prom_gen:
        if 'mod' not in st.session_state or 'tok' not in st.session_state:
            # The model/tokenizer only exist after "Create GPT Model" was clicked.
            st.error('Click "Create GPT Model" before generating prompts.')
        elif not prompt.strip():
            # An empty prompt tokenizes to zero ids, which generate() cannot extend.
            st.warning("Enter some text in GPT PROM to start the prompts from.")
        else:
            model, tokenizer = st.session_state.mod, st.session_state.tok
            input_ids = tokenizer(prompt, return_tensors='pt').input_ids
            # Generation settings follow the model card of the prompt generator.
            output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k,
                                    max_length=max_length, num_return_sequences=num_return_sequences,
                                    repetition_penalty=repitition_penalty, penalty_alpha=0.6,
                                    no_repeat_ngram_size=1, early_stopping=True)

            # Strip special tokens (e.g. end-of-text) so they don't end up in
            # the prompt handed to the diffusion model.
            st.session_state.PROMPTS = tokenizer.batch_decode(output, skip_special_tokens=True)

# Prompt for the diffusion model: free text until GPT suggestions exist in
# this session, then a pick-list of those suggestions.
if 'PROMPTS' not in st.session_state:
    prom = st.text_input("# Prompt", '')
else:
    prom = st.selectbox("Proms", st.session_state.PROMPTS)



                               
# Image-generation settings. The numeric fields use number_input rather than
# free text: the seed is converted with int() on every rerun below, so a
# non-numeric entry previously crashed the whole page.
c1, c2, c3 = st.columns([1, 1, 3])
c4, c5 = st.columns(2)
with c1:
    bu_1 = st.number_input("Seed", min_value=0, value=999, step=1)
with c2:
    bu_2 = st.number_input("Steps", min_value=1, value=12, step=1)
with c3:
    bu_3 = st.number_input("Number of Images", min_value=1, value=1, step=1)
with c4:
    sl_1 = st.slider("Width", 128, 1024, 512, 8)
with c5:
    sl_2 = st.slider("hight", 128, 1024, 512, 8)

# Recreated each time the script runs, seeded from the current Seed value.
st.session_state.generator = torch.Generator("cpu").manual_seed(int(bu_1))

create = st.button("Imagine")

if create:
    if "t2m_mod" not in st.session_state:
        # The pipeline only exists after "Create The Model" has been clicked.
        st.error('Click "Create The Model" before generating images.')
    elif int(bu_3) < 1 or int(bu_2) < 1:
        st.error("Steps and Number of Images must be at least 1.")
    else:
        model = st.session_state.t2m_mod
        generator = st.session_state.generator
        # One prompt per requested image; st.image accepts a list, so the
        # single-image case needs no separate branch.
        with st.spinner("Generating images..."):
            IMGS = model([prom] * int(bu_3), width=int(sl_1), height=int(sl_2),
                         num_inference_steps=int(bu_2),
                         generator=generator).images
        st.image(IMGS)