zamborg commited on
Commit
737c5eb
·
1 Parent(s): ad34fc1

prompt fixing

Browse files
Files changed (2) hide show
  1. app.py +2 -5
  2. model.py +2 -4
app.py CHANGED
@@ -13,10 +13,7 @@ from model import *
13
 
14
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
15
  with st.spinner("Generating Caption"):
16
- if cap_prompt is not "":
17
- subreddit, prompt, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
18
- else:
19
- subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
20
  st.markdown(
21
  f"""
22
  <style>
@@ -31,7 +28,7 @@ def gen_show_caption(sub_prompt=None, cap_prompt = ""):
31
  }}
32
  </style>
33
 
34
- ### <red> r/{subreddit} </red> <blue> {prompt} </blue> {caption}
35
  """,
36
  unsafe_allow_html=True)
37
 
 
13
 
14
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
15
  with st.spinner("Generating Caption"):
16
+ subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
 
 
 
17
  st.markdown(
18
  f"""
19
  <style>
 
28
  }}
29
  </style>
30
 
31
+ ### <red> r/{subreddit} </red> <blue> {cap_prompt} </blue> {caption}
32
  """,
33
  unsafe_allow_html=True)
34
 
model.py CHANGED
@@ -117,16 +117,14 @@ class VirTexModel():
117
  subreddit = "".join(subreddit.split())
118
  rest_of_caption = rest_of_caption.strip()
119
  else:
120
- subreddit, rest_of_caption = "", caption
121
 
122
  # split prompt for coloring:
123
  if prompt is not "":
124
- _, caption = caption.split(prompt)
125
 
126
  is_valid_subreddit = subreddit in self.valid_subs
127
 
128
- if prompt is not "":
129
- return subreddit, prompt, rest_of_caption
130
  return subreddit, rest_of_caption
131
 
132
  def download_files():
 
117
  subreddit = "".join(subreddit.split())
118
  rest_of_caption = rest_of_caption.strip()
119
  else:
120
+ subreddit, rest_of_caption = "", caption.strip()
121
 
122
  # split prompt for coloring:
123
  if prompt is not "":
124
+ _, rest_of_caption = caption.split(prompt.strip())
125
 
126
  is_valid_subreddit = subreddit in self.valid_subs
127
 
 
 
128
  return subreddit, rest_of_caption
129
 
130
  def download_files():