Spaces:
Sleeping
Sleeping
VenkateshRoshan
committed on
Commit
·
bf9aafc
1
Parent(s):
3138612
inference code added
Browse files
app.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request, jsonify
|
2 |
+
from PIL import Image
|
3 |
+
import io
|
4 |
+
from infer import ImageCaptioningInference
|
5 |
+
from models.model import ImageCaptioningModel
|
6 |
+
|
7 |
+
# Flask application object; the route handlers in this module register on it.
app = Flask(__name__)

# Directory the trained captioning model is loaded from.
model_dir = 'model'

# Build and load the model once at import time so all requests share it.
model = ImageCaptioningModel()
model.load(model_dir)

# Inference wrapper used by the upload endpoint.
inference_model = ImageCaptioningInference(model)
|
16 |
+
|
17 |
+
# # Path to the input image
|
18 |
+
# image_path = 'test_img.jpg'
|
19 |
+
|
20 |
+
# # Perform inference and print the generated caption
|
21 |
+
# caption = inference_model.infer_image(image_path)
|
22 |
+
# print("Generated Caption:", caption)
|
23 |
+
|
24 |
+
@app.route('/')
def home():
    """Return the plain-text greeting served at the site root."""
    greeting = "Welcome to the Flask API"
    return greeting
|
27 |
+
|
28 |
+
@app.route('/upload-image', methods=['POST'])
def upload_image():
    """Caption an image uploaded in the multipart form field ``image``.

    Returns:
        A JSON body ``{'generated_caption': <caption>}`` on success.
        A JSON body ``{'error': <message>}`` with status 400 when the
        ``image`` field is missing or cannot be decoded as an image, and
        with status 500 when caption generation itself fails.
    """
    if 'image' not in request.files:
        # A client error must not be reported with the default 200 status.
        return jsonify({'error': 'No image found in request'}), 400

    upload = request.files['image']

    try:
        image = Image.open(io.BytesIO(upload.read()))
        # ToTensor keeps the source channel count, so RGBA / grayscale
        # uploads would otherwise reach the model with 4 or 1 channels.
        # NOTE(review): assumes the image encoder expects 3-channel input.
        # convert() also forces a full decode, surfacing truncated files here.
        image = image.convert('RGB')
    except (OSError, ValueError, Image.DecompressionBombError):
        # Pillow's UnidentifiedImageError is an OSError subclass.
        return jsonify({'error': 'Uploaded file is not a valid image'}), 400

    try:
        generated_caption = inference_model.infer_image(image)
    except Exception:
        # Request boundary: log the traceback server-side and keep the
        # response JSON instead of leaking internals to the client.
        app.logger.exception('Caption generation failed')
        return jsonify({'error': 'Caption generation failed'}), 500

    return jsonify({'generated_caption': generated_caption})
|
47 |
+
|
48 |
+
if __name__ == '__main__':
    import os

    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it must not be on by default for a service that accepts uploads.
    # Opt in explicitly for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')
    app.run(debug=debug_enabled)
|
infer.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from models.model import ImageCaptioningModel
|
3 |
+
from torchvision import transforms
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from transformers import ViTModel, ViTFeatureExtractor, GPT2LMHeadModel, GPT2Tokenizer
|
8 |
+
from PIL import Image
|
9 |
+
from config.config import Config
|
10 |
+
|
11 |
+
class ImageCaptioningInference:
    """Generate a text caption for an image using a captioning model.

    The wrapped ``model`` is only used through three members:
    ``extract_image_features(tensor)``, ``gpt2_model.generate(...)`` and
    ``tokenizer`` (``decode`` plus the ``bos``/``eos`` token ids).
    """

    def __init__(self, model):
        """Store the model and build the image preprocessing pipeline.

        Args:
            model: Captioning model exposing the members listed on the class.
        """
        self.model = model
        self.device = Config.DEVICE
        # Resize to a fixed 224x224 input and convert to a float tensor.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def infer_image(self, image):
        """Return the generated caption for ``image``.

        Args:
            image: A PIL image, or anything ``PIL.Image.open`` accepts
                (a filesystem path or a binary file object).

        Returns:
            The decoded caption string.
        """
        if not hasattr(image, 'convert'):
            # Not an image object yet: open it and release the file handle
            # as soon as the pixel data has been copied by convert().
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        else:
            # ToTensor keeps the source channel count, so normalise RGBA,
            # palette and grayscale images to three channels first.
            # NOTE(review): assumes the image encoder expects RGB input.
            image = image.convert('RGB')

        pixels = self.transform(image).unsqueeze(0).to(self.device)

        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            image_features = self.model.extract_image_features(pixels)
            return self.generate_caption(image_features)

    def generate_caption(self, image_features, num_beams=3, max_length=50):
        """Decode a caption from image features using beam search.

        Args:
            image_features: Feature tensor from ``extract_image_features``.
            num_beams: Beam width passed to ``generate``.
            max_length: Maximum generated length passed to ``generate``.

        Returns:
            The first generated sequence, decoded with special tokens removed.
        """
        # Add a sequence axis so the features act as a one-step embedding
        # prefix (author's note: [batch_size, 1, hidden_size]).
        image_features = image_features.unsqueeze(1)

        tokenizer = self.model.tokenizer
        output = self.model.gpt2_model.generate(
            inputs_embeds=image_features,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            # The EOS id doubles as the padding id here.
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        return tokenizer.decode(output[0], skip_special_tokens=True)
|
50 |
+
|
51 |
+
if __name__ == "__main__":
    # Path to the saved model directory
    model_dir = 'model'

    # Load the trained model and wrap it for inference
    model = ImageCaptioningModel()
    model.load(model_dir)
    inference_model = ImageCaptioningInference(model)

    # Path to the input image
    image_path = 'test_img.jpg'

    # Open the image in a context manager so the file handle is closed
    # deterministically once the caption has been generated.
    with Image.open(image_path) as image:
        caption = inference_model.infer_image(image)

    print("Generated Caption:", caption)
|
69 |
+
|
test_img.jpg
ADDED
![]() |
train.py
CHANGED
@@ -7,12 +7,22 @@ from data.dataLoader import ImageCaptionDataset
|
|
7 |
from config.config import Config
|
8 |
from models.model import ImageCaptioningModel
|
9 |
|
10 |
-
|
|
|
11 |
|
|
|
12 |
|
13 |
|
14 |
def train_model(model,dataLoader, optimizer, loss_fn):
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
model.gpt2_model.train()
|
17 |
for epoch in range(Config.EPOCHS):
|
18 |
epoch_loss = 0
|
@@ -41,10 +51,14 @@ def train_model(model,dataLoader, optimizer, loss_fn):
|
|
41 |
|
42 |
epoch_loss += loss.item()
|
43 |
|
44 |
-
print(f'Epoch {epoch + 1}, Loss: {epoch_loss:.4f}')
|
|
|
45 |
|
46 |
# Save the model
|
47 |
model.save('model')
|
|
|
|
|
|
|
48 |
|
49 |
# return model
|
50 |
|
@@ -75,4 +89,5 @@ if __name__ == '__main__':
|
|
75 |
model = ImageCaptioningModel()
|
76 |
optimizer = torch.optim.Adam(model.gpt2_model.parameters(), lr=Config.LEARNING_RATE)
|
77 |
loss_fn = torch.nn.CrossEntropyLoss()
|
|
|
78 |
train_model(model, dataloader, optimizer, loss_fn)
|
|
|
7 |
from config.config import Config
|
8 |
from models.model import ImageCaptioningModel
|
9 |
|
10 |
+
import mlflow
|
11 |
+
import mlflow.pytorch
|
12 |
|
13 |
+
# TODO: Integrate Weights & Biases for project tracking and evaluation; also add DVC for data versioning.
|
14 |
|
15 |
|
16 |
def train_model(model,dataLoader, optimizer, loss_fn):
|
17 |
|
18 |
+
with mlflow.start_run():
|
19 |
+
mlflow.log_params({
|
20 |
+
"epochs": Config.EPOCHS,
|
21 |
+
"batch_size": Config.BATCH_SIZE,
|
22 |
+
"learning_rate": Config.LEARNING_RATE,
|
23 |
+
"device": Config.DEVICE
|
24 |
+
})
|
25 |
+
|
26 |
model.gpt2_model.train()
|
27 |
for epoch in range(Config.EPOCHS):
|
28 |
epoch_loss = 0
|
|
|
51 |
|
52 |
epoch_loss += loss.item()
|
53 |
|
54 |
+
print(f'Epoch {epoch + 1}, Loss: {epoch_loss/len(dataLoader):.4f}')
|
55 |
+
mlflow.log_metric('loss', epoch_loss/len(dataLoader), step=epoch)
|
56 |
|
57 |
# Save the model
|
58 |
model.save('model')
|
59 |
+
# save the artifacts
|
60 |
+
mlflow.log_artifacts('model')
|
61 |
+
mlflow.pytorch.log_model(model.gpt2_model, "models")
|
62 |
|
63 |
# return model
|
64 |
|
|
|
89 |
model = ImageCaptioningModel()
|
90 |
optimizer = torch.optim.Adam(model.gpt2_model.parameters(), lr=Config.LEARNING_RATE)
|
91 |
loss_fn = torch.nn.CrossEntropyLoss()
|
92 |
+
mlflow.set_experiment('ImageCaptioning')
|
93 |
train_model(model, dataloader, optimizer, loss_fn)
|