# model.py
#
# A video encoder feeds two heads: a linear classifier and, via a small
# projector, a seq2seq language model that generates a free-text report.

import torch
import torch.nn as nn

class ImageToTextProjector(nn.Module):
    """Projects image/video embeddings into the text-embedding space
    expected by the seq2seq report generator."""

    def __init__(self, image_embedding_dim, text_embedding_dim):
        super().__init__()
        self.fc = nn.Linear(image_embedding_dim, text_embedding_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.fc(x)
        x = self.activation(x)
        x = self.dropout(x)
        return x
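
# Example usage for the projector (a sketch; the dims are assumptions,
# e.g. a 512-dim video embedding mapped into a 768-dim text space):
#
#   projector = ImageToTextProjector(image_embedding_dim=512, text_embedding_dim=768)
#   projector(torch.randn(4, 512)).shape  # -> torch.Size([4, 768])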

class CombinedModel(nn.Module):
    """Wraps a video encoder with (a) a classification head and (b) a
    projector + seq2seq model that generates a free-text report."""

    def __init__(self, video_model, report_generator, num_classes, projector, tokenizer):
        super().__init__()
        self.video_model = video_model
        self.report_generator = report_generator
        # NOTE: the classifier assumes video_model outputs 512-dim embeddings.
        self.classifier = nn.Linear(512, num_classes)
        self.projector = projector
        self.dropout = nn.Dropout(p=0.5)
        self.tokenizer = tokenizer  # needed to decode generated ids at inference

    def forward(self, images, labels=None):
        video_embeddings = self.video_model(images)            # (batch, 512)
        video_embeddings = self.dropout(video_embeddings)
        class_outputs = self.classifier(video_embeddings)      # (batch, num_classes)
        projected_embeddings = self.projector(video_embeddings)
        # The seq2seq encoder expects a sequence, so present the projected
        # embedding as a length-1 sequence: (batch, 1, text_embedding_dim).
        encoder_inputs = projected_embeddings.unsqueeze(1)

        if labels is not None:
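            # Training path: teacher forcing against the target report ids
            # yields a language-modeling loss; no text is decoded here.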
            outputs = self.report_generator(
                inputs_embeds=encoder_inputs,
                labels=labels
            )
            gen_loss = outputs.loss
            generated_report = None
        else:
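            # Inference path: beam-search decode, then strip special tokens
            # to recover plain-text reports.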
            generated_report_ids = self.report_generator.generate(
                inputs_embeds=encoder_inputs,
                max_length=512,
                num_beams=4,
                early_stopping=True
            )
            generated_report = self.tokenizer.batch_decode(
                generated_report_ids, skip_special_tokens=True
            )
            gen_loss = None

        return class_outputs, generated_report, gen_loss
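

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline).
    # Assumptions: "t5-small" stands in for the real report generator, and a
    # single Linear layer stands in for the real 512-dim video encoder.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    report_generator = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    video_model = nn.Linear(1024, 512)  # dummy encoder: 1024-dim input -> 512-dim embedding
    projector = ImageToTextProjector(
        image_embedding_dim=512,
        text_embedding_dim=report_generator.config.d_model,  # 512 for t5-small
    )
    model = CombinedModel(video_model, report_generator, num_classes=3,
                          projector=projector, tokenizer=tokenizer)

    model.eval()
    with torch.no_grad():
        class_outputs, report, _ = model(torch.randn(2, 1024))
    print(class_outputs.shape)  # torch.Size([2, 3])
    print(report)               # list of 2 decoded strings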