VenkateshRoshan committed on
Commit bf9aafc · 1 Parent(s): 3138612

inference code added

Files changed (4)
  1. app.py +49 -0
  2. infer.py +69 -0
  3. test_img.jpg +0 -0
  4. train.py +17 -2
app.py ADDED
@@ -0,0 +1,49 @@
+ from flask import Flask, request, jsonify
+ from PIL import Image
+ import io
+ from infer import ImageCaptioningInference
+ from models.model import ImageCaptioningModel
+
+ app = Flask(__name__)
+
+ model_dir = 'model'
+
+ # Initialize inference class
+ model = ImageCaptioningModel()
+ model.load(model_dir)
+
+ inference_model = ImageCaptioningInference(model)
+
+ # # Path to the input image
+ # image_path = 'test_img.jpg'
+
+ # # Perform inference and print the generated caption
+ # caption = inference_model.infer_image(image_path)
+ # print("Generated Caption:", caption)
+
+ @app.route('/')
+ def home():
+     return "Welcome to the Flask API"
+
+ @app.route('/upload-image', methods=['POST'])
+ def upload_image():
+     if 'image' not in request.files:
+         return jsonify({'error': 'No image found in request'})
+
+     image = request.files['image']
+     # print(image)
+
+     # try :
+     image = Image.open(io.BytesIO(image.read()))
+     # image.show()
+
+     generated_caption = inference_model.infer_image(image)
+
+     return jsonify({'generated_caption': generated_caption})
+
+
+     # except Exception as e:
+     #     return jsonify({'error': f'{e}'}), 500
+
+ if __name__ == '__main__':
+     app.run(debug=True)
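
For quick manual testing of the new endpoint, a minimal client sketch (not part of this commit; it assumes the Flask dev server is running locally on its default port 5000, and reuses the repo's test_img.jpg):

import requests

# POST the image as the 'image' form field that upload_image() expects
with open('test_img.jpg', 'rb') as f:
    resp = requests.post('http://127.0.0.1:5000/upload-image', files={'image': f})

print(resp.json())  # e.g. {'generated_caption': '...'}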
infer.py ADDED
@@ -0,0 +1,69 @@
+ from PIL import Image
+ from models.model import ImageCaptioningModel
+ from torchvision import transforms
+ import torch
+
+ import torch
+ from transformers import ViTModel, ViTFeatureExtractor, GPT2LMHeadModel, GPT2Tokenizer
+ from PIL import Image
+ from config.config import Config
+
+ class ImageCaptioningInference:
+     def __init__(self, model):
+         self.model = model
+         self.device = Config.DEVICE
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor()
+         ])
+
+     def infer_image(self, image):
+         # Load and preprocess the image
+         # image = Image.open(image_path)
+         image = self.transform(image).unsqueeze(0).to(self.device)
+
+         # Extract image features
+         image_features = self.model.extract_image_features(image)
+
+         # Generate caption
+         caption = self.generate_caption(image_features)
+         return caption
+
+     def generate_caption(self, image_features, num_beams=3, max_length=50):
+         # Prepare the image features for input
+         image_features = image_features.unsqueeze(1)  # [batch_size, 1, hidden_size]
+
+         # Generate caption using beam search
+         output = self.model.gpt2_model.generate(
+             inputs_embeds=image_features,
+             max_length=max_length,
+             num_beams=num_beams,
+             early_stopping=True,
+             pad_token_id=self.model.tokenizer.eos_token_id,
+             bos_token_id=self.model.tokenizer.bos_token_id,
+             eos_token_id=self.model.tokenizer.eos_token_id
+         )
+
+         # Decode the generated caption
+         caption = self.model.tokenizer.decode(output[0], skip_special_tokens=True)
+         return caption
+
+ if __name__ == "__main__":
+     # Path to the saved model directory
+     model_dir = 'model'
+
+     # Initialize inference class
+     model = ImageCaptioningModel()
+     model.load(model_dir)
+
+     inference_model = ImageCaptioningInference(model)
+
+     # Path to the input image
+     image_path = 'test_img.jpg'
+
+     image = Image.open(image_path)
+
+     # Perform inference and print the generated caption
+     caption = inference_model.infer_image(image)
+     print("Generated Caption:", caption)
+
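
models/model.py is not included in this commit, so the interface infer.py depends on is an assumption. A hypothetical sketch of what it relies on (an extract_image_features() that returns one embedding per image in GPT-2's hidden size, so generate() can consume it via inputs_embeds, plus gpt2_model and tokenizer attributes) might look like:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, ViTModel

class ImageCaptioningModel:  # hypothetical sketch, not the committed models/model.py
    def __init__(self):
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        # project ViT's CLS embedding into GPT-2's input embedding space
        self.projection = torch.nn.Linear(self.vit.config.hidden_size,
                                          self.gpt2_model.config.n_embd)

    def extract_image_features(self, images):
        # images: [batch_size, 3, 224, 224] -> features: [batch_size, n_embd]
        with torch.no_grad():
            cls_embedding = self.vit(pixel_values=images).last_hidden_state[:, 0]
        return self.projection(cls_embedding)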
test_img.jpg ADDED
train.py CHANGED
@@ -7,12 +7,22 @@ from data.dataLoader import ImageCaptionDataset
  from config.config import Config
  from models.model import ImageCaptioningModel
 
- from torchsummary import summary
+ import mlflow
+ import mlflow.pytorch
 
+ # TODO: integrate Weights & Biases for project tracking and evaluation; TODO: add DVC for data versioning
 
 
  def train_model(model,dataLoader, optimizer, loss_fn):
 
+     with mlflow.start_run():
+         mlflow.log_params({
+             "epochs": Config.EPOCHS,
+             "batch_size": Config.BATCH_SIZE,
+             "learning_rate": Config.LEARNING_RATE,
+             "device": Config.DEVICE
+         })
+
      model.gpt2_model.train()
      for epoch in range(Config.EPOCHS):
          epoch_loss = 0
@@ -41,10 +51,14 @@ def train_model(model,dataLoader, optimizer, loss_fn):
 
              epoch_loss += loss.item()
 
-         print(f'Epoch {epoch + 1}, Loss: {epoch_loss:.4f}')
+         print(f'Epoch {epoch + 1}, Loss: {epoch_loss/len(dataLoader):.4f}')
+         mlflow.log_metric('loss', epoch_loss/len(dataLoader), step=epoch)
 
      # Save the model
      model.save('model')
+     # save the artifacts
+     mlflow.log_artifacts('model')
+     mlflow.pytorch.log_model(model.gpt2_model, "models")
 
      # return model
 
@@ -75,4 +89,5 @@ if __name__ == '__main__':
      model = ImageCaptioningModel()
      optimizer = torch.optim.Adam(model.gpt2_model.parameters(), lr=Config.LEARNING_RATE)
      loss_fn = torch.nn.CrossEntropyLoss()
+     mlflow.set_experiment('ImageCaptioning')
      train_model(model, dataloader, optimizer, loss_fn)
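
Once a run completes, the weights logged by mlflow.pytorch.log_model can be pulled back for inference. A small sketch (the run ID is a placeholder for whatever MLflow assigns; 'models' matches the artifact path used in train.py):

import mlflow.pytorch

# load the GPT-2 head logged by train.py; find the run ID in the MLflow UI
gpt2_model = mlflow.pytorch.load_model('runs:/<run_id>/models')
gpt2_model.eval()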