efeperro committed on
Commit a14156a · verified · 1 Parent(s): 21c4bf6

Create app.py

Files changed (1)
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import T5Tokenizer, ViTFeatureExtractor
+
+ # Model loading and setting up the device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = torch.load("model_vit_ai.pt", map_location=device)
+ model.to(device)
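+ # Note: torch.load() here unpickles a full nn.Module, which only works if the
+ # model's class definition is importable in this environment. A state_dict
+ # checkpoint is the more portable pattern (sketch only, with a hypothetical
+ # CaptioningModel class standing in for the real one):
+ #   model = CaptioningModel()
+ #   model.load_state_dict(torch.load("model_vit_ai.pt", map_location=device))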
+
+ # Tokenizer and Feature Extractor
+ tokenizer = T5Tokenizer.from_pretrained('t5-base')
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
+
+ # Define the image preprocessing
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
+ ])
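+ # The normalization stats are read from the ViT checkpoint's preprocessing
+ # config rather than hard-coded, so inference matches the encoder's
+ # pretraining pipeline. (ViTFeatureExtractor is deprecated in recent
+ # transformers releases; ViTImageProcessor is the drop-in replacement.)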
+
+ def preprocess_image(image):
+     # Gradio passes the image as a numpy array: convert to PIL, apply the
+     # transform, and add a batch dimension
+     image = Image.fromarray(image.astype('uint8'), 'RGB')
+     image = transform(image)
+     return image.unsqueeze(0)
+
+ def generate_caption(image):
+     model.eval()
+     with torch.no_grad():
+         image_tensor = preprocess_image(image).to(device)
+         # Seed the decoder with its start token
+         decoder_input_ids = torch.full((1, 1), model.decoder_start_token_id, dtype=torch.long, device=device)
+
+         # Greedy decoding: append the highest-logit token until EOS or 50 steps
+         for _ in range(50):
+             outputs = model(images=image_tensor, decoder_ids=decoder_input_ids)
+             next_token_logits = outputs.logits[:, -1, :]
+             next_token_id = next_token_logits.argmax(1, keepdim=True)
+             decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
+
+             if torch.eq(next_token_id, tokenizer.eos_token_id).all():
+                 break
+
+     caption = tokenizer.decode(decoder_input_ids.squeeze(0), skip_special_tokens=True)
+     return caption
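+ # Greedy argmax decoding is deterministic but tends to produce flat captions.
+ # Beam search would be a one-line swap, sketched here under the assumption
+ # (not guaranteed by this custom checkpoint) that the model exposes an
+ # HF-style generate() method:
+ #   output_ids = model.generate(image_tensor, num_beams=4, max_length=50)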
+
+ sample_images = [
+     "sample_image1.jpg",
+     "sample_image2.jpg",
+     "sample_image3.jpg"
+ ]
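+ # These example files are assumed to be committed alongside app.py;
+ # gr.Interface lists them as one-click examples under the input widget.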
+
+ # Define Gradio interface (gr.inputs.Image and its source=/tool= kwargs were
+ # removed in Gradio 4; gr.Image with sources= is the current equivalent)
+ interface = gr.Interface(
+     fn=generate_caption,
+     inputs=gr.Image(sources=["upload", "webcam"], type="numpy", label="Upload an image or take a photo"),
+     outputs="text",
+     examples=sample_images,
+     title="Image Captioning Model",
+     description="Upload an image, select a sample image, or use your webcam to take a photo and generate a caption."
+ )
+
+ # Run the interface
+ interface.launch(debug=True)