mrbeliever committed on
Commit
f07f9e1
·
verified ·
1 Parent(s): b1eb0f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -20,7 +20,7 @@ DEFAULT_QUERY = (
20
  "Avoid subjective interpretations or speculation."
21
  )
22
 
23
- DTYPE = torch.float16 # Use float16 for faster processing on CPU with limited resources
24
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
  tokenizer = LlamaTokenizer.from_pretrained(
@@ -39,12 +39,11 @@ model = model.to(device=DEVICE)
39
  @torch.no_grad()
40
  def generate_caption(
41
  image: Image.Image,
42
- query: str = DEFAULT_QUERY,
43
  params: dict[str, Any] = DEFAULT_PARAMS,
44
  ) -> str:
45
  inputs = model.build_conversation_input_ids(
46
  tokenizer=tokenizer,
47
- query=query,
48
  history=[],
49
  images=[image],
50
  )
@@ -56,7 +55,7 @@ def generate_caption(
56
  }
57
 
58
  outputs = model.generate(**inputs, **params)
59
- outputs = outputs[:, inputs["input_ids"].shape[1]:]
60
  result = tokenizer.decode(outputs[0])
61
 
62
  result = result.replace("This image showcases", "").strip().removesuffix("</s>").strip().capitalize()
@@ -65,14 +64,14 @@ def generate_caption(
65
  with gr.Blocks() as demo:
66
  with gr.Row():
67
  with gr.Column():
68
- input_image = gr.Image(type="pil")
69
  run_button = gr.Button(value="Generate Caption")
70
  with gr.Column():
71
  output_caption = gr.Textbox(label="Generated Caption", show_copy_button=True)
72
 
73
  run_button.click(
74
  fn=generate_caption,
75
- inputs=[input_image], # Only input image is needed
76
  outputs=output_caption,
77
  )
78
 
 
20
  "Avoid subjective interpretations or speculation."
21
  )
22
 
23
+ DTYPE = torch.bfloat16
24
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
  tokenizer = LlamaTokenizer.from_pretrained(
 
39
  @torch.no_grad()
40
  def generate_caption(
41
  image: Image.Image,
 
42
  params: dict[str, Any] = DEFAULT_PARAMS,
43
  ) -> str:
44
  inputs = model.build_conversation_input_ids(
45
  tokenizer=tokenizer,
46
+ query=DEFAULT_QUERY, # Use the default query directly
47
  history=[],
48
  images=[image],
49
  )
 
55
  }
56
 
57
  outputs = model.generate(**inputs, **params)
58
+ outputs = outputs[:, inputs["input_ids"].shape[1] :]
59
  result = tokenizer.decode(outputs[0])
60
 
61
  result = result.replace("This image showcases", "").strip().removesuffix("</s>").strip().capitalize()
 
64
  with gr.Blocks() as demo:
65
  with gr.Row():
66
  with gr.Column():
67
+ input_image = gr.Image(type="pil") # Image input remains
68
  run_button = gr.Button(value="Generate Caption")
69
  with gr.Column():
70
  output_caption = gr.Textbox(label="Generated Caption", show_copy_button=True)
71
 
72
  run_button.click(
73
  fn=generate_caption,
74
+ inputs=[input_image], # Only the image input is passed
75
  outputs=output_caption,
76
  )
77