LHRuig commited on
Commit
bc2f7fd
·
verified ·
1 Parent(s): dc0e6a4

Delete caption.py

Browse files
Files changed (1) hide show
  1. caption.py +0 -20
caption.py DELETED
@@ -1,20 +0,0 @@
1
- from transformers import Blip2Processor, Blip2ForConditionalGeneration
2
- from PIL import Image
3
- import torch
4
-
5
- def generate_caption(image_path, trigger_word):
6
- processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
7
- model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model.to(device)
10
-
11
- image = Image.open(image_path)
12
- inputs = processor(image, return_tensors="pt").to(device, torch.float16)
13
- generated_ids = model.generate(**inputs, max_new_tokens=50)
14
- caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
15
-
16
- return f"a photo of [{trigger_word}], {caption}"
17
-
18
- # Example:
19
- caption = generate_caption("image.jpg", "my_char")
20
- print(caption) # Output: "a photo of [my_char], a woman smiling in a park"