zamborg commited on
Commit
89b96b3
·
1 Parent(s): 7925ce3

updated prompt splitting

Browse files
Files changed (2) hide show
  1. app.py +25 -22
  2. model.py +6 -2
app.py CHANGED
@@ -13,21 +13,25 @@ from model import *
13
 
14
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
15
  with st.spinner("Generating Caption"):
16
- if sub_prompt is None and cap_prompt is not "":
17
- st.write("Without a specified subreddit we default to /r/pics")
18
- subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
 
19
  st.markdown(
20
  f"""
21
  <style>
22
  red{{
23
  color:#c62828
24
  }}
 
 
 
25
  mono{{
26
  font-family: "Inconsolata";
27
  }}
28
  </style>
29
 
30
- ### <red> r/{subreddit} </red> {caption}
31
  """,
32
  unsafe_allow_html=True)
33
 
@@ -109,31 +113,30 @@ if advanced:
109
  nuc_size = st.sidebar.slider("Nucelus Size:", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
110
  virtexModel.model.decoder.nucleus_size = nuc_size
111
 
112
- if False: #uploaded_image is None:# and submitted:
113
- st.write("Please select a file to upload")
114
 
 
 
 
 
 
115
  else:
116
- image_file = sample_image
117
 
118
- # LOAD AND CACHE THE IMAGE
119
- if uploaded_image is not None:
120
- image = uploaded_image
121
- elif select_idx is None and 'image' in st.session_state:
122
- image = st.session_state['image']
123
- else:
124
- image = Image.open(image_file)
125
 
126
- image = image.convert("RGB")
127
 
128
- st.session_state['image'] = image
129
 
 
130
 
131
- image_dict = imageLoader.transform(image)
132
 
133
- show_image = imageLoader.show_resize(image)
 
134
 
135
- show = st.image(show_image)
136
- show.image(show_image, "Your Image")
137
 
138
- for i in range(num_captions):
139
- gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
 
 
 
13
 
14
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
15
  with st.spinner("Generating Caption"):
16
+ if cap_prompt is not "":
17
+ subreddit, prompt, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
18
+ else:
19
+ subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
20
  st.markdown(
21
  f"""
22
  <style>
23
  red{{
24
  color:#c62828
25
  }}
26
+ blue{{
27
+ color:#2a72d5
28
+ }}
29
  mono{{
30
  font-family: "Inconsolata";
31
  }}
32
  </style>
33
 
34
+ ### <red> r/{subreddit} </red> <blue> {prompt} </blue> {caption}
35
  """,
36
  unsafe_allow_html=True)
37
 
 
113
  nuc_size = st.sidebar.slider("Nucelus Size:", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
114
  virtexModel.model.decoder.nucleus_size = nuc_size
115
 
116
+ image_file = sample_image
 
117
 
118
+ # LOAD AND CACHE THE IMAGE
119
+ if uploaded_image is not None:
120
+ image = uploaded_image
121
+ elif select_idx is None and 'image' in st.session_state:
122
+ image = st.session_state['image']
123
  else:
124
+ image = Image.open(image_file)
125
 
126
+ image = image.convert("RGB")
 
 
 
 
 
 
127
 
128
+ st.session_state['image'] = image
129
 
 
130
 
131
+ image_dict = imageLoader.transform(image)
132
 
133
+ show_image = imageLoader.show_resize(image)
134
 
135
+ show = st.image(show_image)
136
+ show.image(show_image, "Your Image")
137
 
 
 
138
 
139
+ if sub is None and imageLoader.text_transform(cap_prompt) is not "":
140
+ st.write("Without a specified subreddit we default to /r/pics")
141
+ for i in range(num_captions):
142
+ gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
model.py CHANGED
@@ -92,7 +92,6 @@ class VirTexModel():
92
  subreddit_tokens = torch.cat(
93
  [
94
  subreddit_tokens,
95
- torch.tensor([self.tokenizer.token_to_id("[SEP]")], device=self.device).long(),
96
  cap_tokens
97
  ])
98
 
@@ -119,10 +118,15 @@ class VirTexModel():
119
  rest_of_caption = rest_of_caption.strip()
120
  else:
121
  subreddit, rest_of_caption = "", caption
 
 
 
 
122
 
123
  is_valid_subreddit = subreddit in self.valid_subs
124
 
125
-
 
126
  return subreddit, rest_of_caption
127
 
128
  def download_files():
 
92
  subreddit_tokens = torch.cat(
93
  [
94
  subreddit_tokens,
 
95
  cap_tokens
96
  ])
97
 
 
118
  rest_of_caption = rest_of_caption.strip()
119
  else:
120
  subreddit, rest_of_caption = "", caption
121
+
122
+ # split prompt for coloring:
123
+ if prompt is not "":
124
+ _, caption = caption.split(prompt)
125
 
126
  is_valid_subreddit = subreddit in self.valid_subs
127
 
128
+ if prompt is not "":
129
+ return subreddit, prompt, rest_of_caption
130
  return subreddit, rest_of_caption
131
 
132
  def download_files():