bobber committed on
Commit
83f96b0
·
verified ·
1 Parent(s): eb70836

Update app.py

Browse files

add vlm_prompt

Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -11,7 +11,7 @@ import os
11
 
12
 
13
  CLIP_PATH = "google/siglip-so400m-patch14-384"
14
- VLM_PROMPT = "A descriptive caption for this image:\n"
15
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
16
  CHECKPOINT_PATH = Path("wpkklhc6")
17
  TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
@@ -63,7 +63,7 @@ image_adapter.to("cuda")
63
 
64
  @spaces.GPU()
65
  @torch.no_grad()
66
- def stream_chat(input_image: Image.Image):
67
  torch.cuda.empty_cache()
68
 
69
  # Preprocess image
@@ -71,7 +71,10 @@ def stream_chat(input_image: Image.Image):
71
  image = image.to('cuda')
72
 
73
  # Tokenize the prompt
74
- prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
 
 
 
75
 
76
  # Embed image
77
  with torch.amp.autocast_mode.autocast('cuda', enabled=True):
@@ -121,8 +124,18 @@ with gr.Blocks() as demo:
121
 
122
  with gr.Column():
123
  output_caption = gr.Textbox(label="Caption")
 
 
 
 
 
 
 
 
 
 
124
 
125
- run_button.click(fn=stream_chat, inputs=[input_image], outputs=[output_caption])
126
 
127
 
128
  if __name__ == "__main__":
 
11
 
12
 
13
  CLIP_PATH = "google/siglip-so400m-patch14-384"
14
+ VLM_PROMPT = "A descriptive caption for this image:"
15
  MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
16
  CHECKPOINT_PATH = Path("wpkklhc6")
17
  TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
 
63
 
64
  @spaces.GPU()
65
  @torch.no_grad()
66
+ def stream_chat(input_image: Image.Image, vlm_prompt):
67
  torch.cuda.empty_cache()
68
 
69
  # Preprocess image
 
71
  image = image.to('cuda')
72
 
73
  # Tokenize the prompt
74
+ if not vlm_prompt:
75
+ vlm_prompt = VLM_PROMPT
76
+ vlm_prompt = vlm_prompt + "\n"
77
+ prompt = tokenizer.encode(vlm_prompt, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
78
 
79
  # Embed image
80
  with torch.amp.autocast_mode.autocast('cuda', enabled=True):
 
124
 
125
  with gr.Column():
126
  output_caption = gr.Textbox(label="Caption")
127
+
128
+ with gr.Row():
129
+ vlm_prompt = gr.Text(
130
+ label="VLM Prompt",
131
+ show_label=False,
132
+ max_lines=1,
133
+ placeholder="Enter your VLM prompt",
134
+ container=False,
135
+ value="A descriptive caption for this image:",
136
+ )
137
 
138
+ run_button.click(fn=stream_chat, inputs=[input_image, vlm_prompt], outputs=[output_caption])
139
 
140
 
141
  if __name__ == "__main__":