ayushrupapara committed on
Commit 77675f2 · verified · 1 Parent(s): dda927f

Update app.py

Files changed (1)
  1. app.py +97 -97
app.py CHANGED
@@ -1,98 +1,98 @@
-# app.py (Content of this file should be your 'gradio_code_debugged_v2' from previous steps)
-import gradio as gr
-import torch
-import torch.nn as nn
-from transformers import ViTModel, GPT2LMHeadModel, GPT2TokenizerFast, ViTFeatureExtractor, GPT2Config
-from huggingface_hub import hf_hub_download
-from PIL import Image
-import asyncio
-import concurrent.futures
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Load Model & Tokenizer
-class ViT_GPT2_Captioner(nn.Module):
-    def __init__(self):
-        super(ViT_GPT2_Captioner, self).__init__()
-        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
-        gpt2_config = GPT2Config.from_pretrained('gpt2')
-        gpt2_config.add_cross_attention = True
-        self.gpt2 = GPT2LMHeadModel.from_pretrained('gpt2', config=gpt2_config)
-        self.bridge = nn.Linear(self.vit.config.hidden_size, self.gpt2.config.n_embd)
-        for param in self.vit.parameters():
-            param.requires_grad = False
-
-    def forward(self, pixel_values, captions, attention_mask=None):
-        visual_features = self.vit(pixel_values=pixel_values).last_hidden_state
-        projected_features = self.bridge(visual_features[:, 0, :])
-        outputs = self.gpt2(input_ids=captions, attention_mask=attention_mask,
-                            encoder_hidden_states=projected_features.unsqueeze(1),
-                            encoder_attention_mask=torch.ones(projected_features.size(0), 1).to(projected_features.device))
-        return outputs.logits
-
-model_path = hf_hub_download(repo_id="ayushrupapara/vit-gpt2-flickr8k-image-captioner", filename="model.pth") # Correct repo_id
-model = ViT_GPT2_Captioner().to(device)
-model.load_state_dict(torch.load(model_path, map_location=device))
-model.eval()
-
-tokenizer = GPT2TokenizerFast.from_pretrained("ayushrupapara/vit-gpt2-flickr8k-image-captioner", force_download=True) # Correct repo_id
-tokenizer.pad_token = tokenizer.eos_token
-feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
-
-import asyncio
-import concurrent.futures
-
-executor = concurrent.futures.ThreadPoolExecutor()
-
-# beam search with tunning
-async def generate_caption_async(image, num_beams, temperature):
-    loop = asyncio.get_event_loop()
-    return await loop.run_in_executor(executor, generate_caption_sync, image, num_beams, temperature)
-
-def generate_caption_sync(image, num_beams=5, temperature=0.5, max_length=20):
-    #print(f"Received max_length: {max_length}, Type: {type(max_length)}")
-    max_length = int(max_length)
-    #print(f"Max_length after int conversion: {max_length}, Type: {type(max_length)}")
-
-
-    if image is None:
-        return "No image uploaded"
-    if isinstance(image, Image.Image):
-        image = image.convert("RGB")
-    else:
-        raise TypeError("Invalid image format. Expected a PIL Image.")
-
-    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
-
-    with torch.no_grad():
-        input_ids = torch.tensor([[tokenizer.eos_token_id]], device=device)
-        output_ids = model.gpt2.generate( # Using model.gpt2.generate for beam search
-            inputs=input_ids,
-            encoder_hidden_states=model.bridge(model.vit(pixel_values=pixel_values).last_hidden_state[:, 0, :]).unsqueeze(1),
-            max_length=max_length,
-            num_beams=num_beams,
-            temperature=temperature,
-            length_penalty=0.9,
-            no_repeat_ngram_size=2,
-            early_stopping=True,
-            pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-        )
-
-    caption = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)
-    return caption
-
-
-iface = gr.Interface(fn=generate_caption_async,
-                     inputs=[
-                         gr.Image(type="pil"),
-                         gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Beams (num_beams)"),
-                         gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.7, label="Temperature")
-                     ],
-                     outputs="text",
-                     title="ViT-GPT2 Image Captioning",
-                     description="Upload an image to get a caption.")
-
-
-
+# app.py (Content of this file should be your 'gradio_code_debugged_v2' from previous steps)
+import gradio as gr
+import torch
+import torch.nn as nn
+from transformers import ViTModel, GPT2LMHeadModel, GPT2TokenizerFast, ViTFeatureExtractor, GPT2Config
+from huggingface_hub import hf_hub_download
+from PIL import Image
+import asyncio
+import concurrent.futures
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load Model & Tokenizer
+class ViT_GPT2_Captioner(nn.Module):
+    def __init__(self):
+        super(ViT_GPT2_Captioner, self).__init__()
+        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
+        gpt2_config = GPT2Config.from_pretrained('gpt2')
+        gpt2_config.add_cross_attention = True
+        self.gpt2 = GPT2LMHeadModel.from_pretrained('gpt2', config=gpt2_config)
+        self.bridge = nn.Linear(self.vit.config.hidden_size, self.gpt2.config.n_embd)
+        for param in self.vit.parameters():
+            param.requires_grad = False
+
+    def forward(self, pixel_values, captions, attention_mask=None):
+        visual_features = self.vit(pixel_values=pixel_values).last_hidden_state
+        projected_features = self.bridge(visual_features[:, 0, :])
+        outputs = self.gpt2(input_ids=captions, attention_mask=attention_mask,
+                            encoder_hidden_states=projected_features.unsqueeze(1),
+                            encoder_attention_mask=torch.ones(projected_features.size(0), 1).to(projected_features.device))
+        return outputs.logits
+
+model_path = hf_hub_download(repo_id="ayushrupapara/vit-gpt2-flickr8k-image-captioner", filename="model.pth") # Correct repo_id
+model = ViT_GPT2_Captioner().to(device)
+model.load_state_dict(torch.load(model_path, map_location=device))
+model.eval()
+
+tokenizer = GPT2TokenizerFast.from_pretrained("ayushrupapara/vit-gpt2-flickr8k-image-captioner", force_download=True) # Correct repo_id
+tokenizer.pad_token = tokenizer.eos_token
+feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
+
+import asyncio
+import concurrent.futures
+
+executor = concurrent.futures.ThreadPoolExecutor()
+
+# beam search with tunning
+async def generate_caption_async(image, num_beams, temperature):
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(executor, generate_caption_sync, image, num_beams, temperature)
+
+def generate_caption_sync(image, num_beams=5, temperature=0.5, max_length=20):
+    #print(f"Received max_length: {max_length}, Type: {type(max_length)}")
+    max_length = int(max_length)
+    #print(f"Max_length after int conversion: {max_length}, Type: {type(max_length)}")
+
+
+    if image is None:
+        return "No image uploaded"
+    if isinstance(image, Image.Image):
+        image = image.convert("RGB")
+    else:
+        raise TypeError("Invalid image format. Expected a PIL Image.")
+
+    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
+
+    with torch.no_grad():
+        input_ids = torch.tensor([[tokenizer.eos_token_id]], device=device)
+        output_ids = model.gpt2.generate( # Using model.gpt2.generate for beam search
+            inputs=input_ids,
+            encoder_hidden_states=model.bridge(model.vit(pixel_values=pixel_values).last_hidden_state[:, 0, :]).unsqueeze(1),
+            max_length=max_length,
+            num_beams=num_beams,
+            temperature=temperature,
+            length_penalty=0.9,
+            no_repeat_ngram_size=2,
+            early_stopping=True,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+
+    caption = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)
+    return caption
+
+
+iface = gr.Interface(fn=generate_caption_async,
+                     inputs=[
+                         gr.Image(type="pil"),
+                         gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Beams"),
+                         gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.7, label="Temperature")
+                     ],
+                     outputs="text",
+                     title="ViT-GPT2 Image Captioning",
+                     description="Upload an image to get a caption.")
+
+
+
  iface.launch() # Removed debug=True for deployment