3laa2 committed
Commit d90716a · 1 Parent(s): ef21c7e

Update app.py


Added a GPT prompt generator

Files changed (1)
1. app.py (+57 -3)
app.py CHANGED
@@ -3,6 +3,7 @@ import cv2 as cv
 import time
 import torch
 from diffusers import StableDiffusionPipeline
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
 
 def create_model(loc = "stabilityai/stable-diffusion-2-1-base", mch = 'cpu'):
@@ -10,6 +11,14 @@ def create_model(loc = "stabilityai/stable-diffusion-2-1-base", mch = 'cpu'):
     pipe = pipe.to(mch)
     return pipe
 
+
+def tok_mod():
+    tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
+    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+    model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')
+    return model,tokenizer
+
+
 t2i = st.title("""
 Txt2Img
 ###### `CLICK "Create_Update_Model"` :
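For context: the new `tok_mod()` helper pairs the stock `distilgpt2` tokenizer with `FredZhang7/distilgpt2-stable-diffusion-v2`, a distilgpt2 checkpoint fine-tuned to write Stable Diffusion prompts. A minimal standalone sketch of the same loading step, assuming only that `transformers` is installed:

```python
# Sketch of what tok_mod() wires together, runnable outside the Streamlit app.
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
# GPT-2 ships without a padding token; one is added so batched generation
# (num_return_sequences > 1) has something to pad with.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')
```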
@@ -18,17 +27,62 @@ Txt2Img
 
 the_type = st.selectbox("Model",("stabilityai/stable-diffusion-2-1-base",
                                  "CompVis/stable-diffusion-v1-4"))
+st.session_state.gate = False
+
+ma_1,_,ma_2 = st.columns([1,3,1])
 
-create = st.button("Create The Model")
+with ma_1 :
+    create = st.button("Create The Model")
 
 if create:
     st.session_state.t2m_mod = create_model(loc=the_type)
 
-prom = st.text_input("# Prompt",'')
+with ma_2 :
+    gpt = st.checkbox("GPT PROMS")
+
+if gpt :
+    gen = st.button("Create GPT Model")
+    if gen:
+        st.session_state.mod,st.session_state.tok = tok_mod()
+
+    m1,m2,m3 = st.columns([1,1,3])
+    m4,m5 = st.columns(2)
+    prompt = st.text_input("GPT PROM",r'' )
+    with m1 :
+        temperature = st.slider("Temp",0.0,1.0,.9,.1)
+    with m2 :
+        top_k = st.slider("K",2,16,8,2)
+    with m3 :
+        max_length = st.slider("Length",10,100,80,1)
+    with m4 :
+        repitition_penalty = st.slider("penality",1.0,5.0,1.2,1.0)
+    with m5 :
+        num_return_sequences = st.slider("Proms Num",1,10,5,1)
+
+    prom_gen = st.button("Generate Proms")
 
+    if prom_gen :
+        model, tokenizer = st.session_state.mod,st.session_state.tok
+        input_ids = tokenizer(prompt, return_tensors='pt').input_ids
+        output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k, max_length=max_length,
+                                num_return_sequences=num_return_sequences, repetition_penalty=repitition_penalty,
+                                penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)
+
+        st.session_state.PROMPTS = []
+        for i in range(len(output)):
+            st.session_state.PROMPTS.append(tokenizer.decode(output[i]))
+
+if 'PROMPTS' in st.session_state :
+    prom = st.selectbox("Proms",st.session_state.PROMPTS)
+
+else :
+    prom = st.text_input("# Prompt",'')
+
+
+
+
 c1,c2,c3 = st.columns([1,1,3])
 c4,c5 = st.columns(2)
-
 with c1:
     bu_1 = st.text_input("Seed",'999')
 with c2:
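The `prom_gen` branch above is the whole generation path: tokenize the seed text, sample `num_return_sequences` continuations, and decode them into `st.session_state.PROMPTS`. A standalone sketch of that path using the sliders' default values (the seed prompt below is a hypothetical stand-in; in the app all five parameters come from the widgets):

```python
# Standalone sketch of the "Generate Proms" step with the sliders' defaults.
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')

prompt = "a cat wearing a spacesuit"  # hypothetical seed text for "GPT PROM"
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(input_ids,
                        do_sample=True,            # sample rather than decode greedily
                        temperature=0.9,           # "Temp" default
                        top_k=8,                   # "K" default
                        max_length=80,             # "Length" default
                        num_return_sequences=5,    # "Proms Num" default
                        repetition_penalty=1.2,    # "penality" default
                        penalty_alpha=0.6, no_repeat_ngram_size=1,
                        early_stopping=True)

# The app decodes raw; skip_special_tokens=True additionally drops the [PAD] token.
for seq in output:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```

Each decoded string then feeds the "Proms" selectbox, which replaces the plain "# Prompt" text box whenever `PROMPTS` is present in the session state.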