zamborg commited on
Commit
d92334f
·
1 Parent(s): 9fa3fe8

prompting with no subreddit

Browse files
Files changed (1) hide show
  1. model.py +9 -4
model.py CHANGED
@@ -63,9 +63,14 @@ class VirTexModel():
63
  subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
64
 
65
  if prompt is not "":
66
- cap_tokens = self.tokenizer.encode(prompt)
67
- cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
68
- subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
 
 
 
 
 
69
 
70
  predictions: List[Dict[str, Any]] = []
71
 
@@ -90,7 +95,7 @@ class VirTexModel():
90
  else:
91
  subreddit, rest_of_caption = "", caption
92
 
93
- is_valid_subreddit = True if sub_prompt is not None else subreddit in self.valid_subs
94
 
95
 
96
  return subreddit, rest_of_caption
 
63
  subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
64
 
65
  if prompt is not "":
66
+ if sub_prompt is not None:
67
+ cap_tokens = self.tokenizer.encode(prompt)
68
+ cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
69
+ subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
70
+ else:
71
+ st.write("Without a specified subreddit, caption prompts will skip subreddit prediction")
72
+ #TODO fix this
73
+
74
 
75
  predictions: List[Dict[str, Any]] = []
76
 
 
95
  else:
96
  subreddit, rest_of_caption = "", caption
97
 
98
+ is_valid_subreddit = True if sub_prompt is not None or prompt is not None else subreddit in self.valid_subs
99
 
100
 
101
  return subreddit, rest_of_caption