Spaces:
Paused
Paused
fancyfeast
commited on
Commit
•
5d57e40
1
Parent(s):
f73cf3f
Improve handling caption tone special case. Also, derp, forgot to format the prompt string.
Browse files
app.py
CHANGED
@@ -144,12 +144,20 @@ image_adapter.to("cuda")
|
|
144 |
def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int) -> str:
|
145 |
torch.cuda.empty_cache()
|
146 |
|
|
|
147 |
length = None if caption_length == "any" else caption_length
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
|
149 |
if prompt_key not in CAPTION_TYPE_MAP:
|
150 |
raise ValueError(f"Invalid caption type: {prompt_key}")
|
151 |
|
152 |
-
prompt_str = CAPTION_TYPE_MAP[prompt_key][0]
|
|
|
153 |
|
154 |
# Preprocess image
|
155 |
#image = clip_processor(images=input_image, return_tensors='pt').pixel_values
|
@@ -230,6 +238,8 @@ with gr.Blocks() as demo:
|
|
230 |
value="any",
|
231 |
)
|
232 |
|
|
|
|
|
233 |
run_button = gr.Button("Caption")
|
234 |
|
235 |
with gr.Column():
|
|
|
144 |
def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int) -> str:
|
145 |
torch.cuda.empty_cache()
|
146 |
|
147 |
+
# 'any' means no length specified
|
148 |
length = None if caption_length == "any" else caption_length
|
149 |
+
|
150 |
+
# 'rng-tags' and 'training_prompt' don't have formal/informal tones
|
151 |
+
if caption_type == "rng-tags" or caption_type == "training_prompt":
|
152 |
+
caption_tone = "formal"
|
153 |
+
|
154 |
+
# Build prompt
|
155 |
prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
|
156 |
if prompt_key not in CAPTION_TYPE_MAP:
|
157 |
raise ValueError(f"Invalid caption type: {prompt_key}")
|
158 |
|
159 |
+
prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
|
160 |
+
print(f"Prompt: {prompt_str}")
|
161 |
|
162 |
# Preprocess image
|
163 |
#image = clip_processor(images=input_image, return_tensors='pt').pixel_values
|
|
|
238 |
value="any",
|
239 |
)
|
240 |
|
241 |
+
gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.")
|
242 |
+
|
243 |
run_button = gr.Button("Caption")
|
244 |
|
245 |
with gr.Column():
|