# model.py
"""Model definitions for joint video classification and report generation."""

import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM
from torchvision import models


class ImageToTextProjector(nn.Module):
    """Project image/video embeddings into the text model's embedding space.

    Applies a linear layer, a ReLU and dropout (p=0.5), in that order.

    Args:
        image_embedding_dim: Size of the last dimension of the input.
        text_embedding_dim: Size of the last dimension of the output.
    """

    def __init__(self, image_embedding_dim, text_embedding_dim):
        super(ImageToTextProjector, self).__init__()
        self.fc = nn.Linear(image_embedding_dim, text_embedding_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return ``dropout(relu(fc(x)))``."""
        return self.dropout(self.activation(self.fc(x)))


class CombinedModel(nn.Module):
    """Classify a video and generate (or score) a text report for it.

    Args:
        video_model: Module mapping ``images`` to embeddings whose last
            dimension equals ``video_embedding_dim``.
        report_generator: Model that is called with ``inputs_embeds`` and
            ``labels`` and exposes ``.loss`` on its output, and that provides
            ``generate(inputs_embeds=...)``.
            NOTE(review): presumably a Hugging Face seq2seq model -- confirm.
        num_classes: Number of outputs of the classification head.
        projector: Module mapping video embeddings to the embedding size
            expected by ``report_generator``.
        tokenizer: Object with ``batch_decode(ids, skip_special_tokens=...)``
            used to turn generated ids into text. Only needed when
            ``forward`` is called without ``labels``.
        video_embedding_dim: Input width of the classification head.
            Defaults to 512, the value that was previously hard-coded.
    """

    def __init__(
        self,
        video_model,
        report_generator,
        num_classes,
        projector,
        *,
        tokenizer=None,
        video_embedding_dim=512,
    ):
        super(CombinedModel, self).__init__()
        self.video_model = video_model
        self.report_generator = report_generator
        self.classifier = nn.Linear(video_embedding_dim, num_classes)
        self.projector = projector
        self.tokenizer = tokenizer
        self.dropout = nn.Dropout(p=0.5)

    def _resolve_tokenizer(self):
        """Return the tokenizer used to decode generated report ids.

        Prefers the tokenizer given to the constructor. Falls back to a
        module-level ``report_generator_tokenizer`` so code that injected that
        global keeps working.

        Raises:
            RuntimeError: If no tokenizer is available from either source.
        """
        if self.tokenizer is not None:
            return self.tokenizer
        legacy_tokenizer = globals().get("report_generator_tokenizer")
        if legacy_tokenizer is None:
            raise RuntimeError(
                "CombinedModel needs a tokenizer to decode generated reports; "
                "pass tokenizer=... to the constructor."
            )
        return legacy_tokenizer

    def forward(self, images, labels=None):
        """Run classification plus report scoring or report generation.

        Args:
            images: Input passed unchanged to ``video_model``.
            labels: Optional target ids for the report. When given, the
                generator is run with these labels and its loss is returned;
                when ``None``, a report is generated with beam search.

        Returns:
            A tuple ``(class_outputs, generated_report, gen_loss)``.
            ``generated_report`` is ``None`` when ``labels`` is given;
            ``gen_loss`` is ``None`` when it is not.

        Raises:
            RuntimeError: If ``labels`` is ``None`` and no tokenizer is
                available to decode the generated ids.
        """
        video_embeddings = self.dropout(self.video_model(images))
        class_outputs = self.classifier(video_embeddings)

        # Insert a new axis at position 1 before handing the embeddings to the
        # generator (presumably a length-1 sequence axis -- TODO confirm).
        encoder_inputs = self.projector(video_embeddings).unsqueeze(1)

        if labels is not None:
            outputs = self.report_generator(
                inputs_embeds=encoder_inputs,
                labels=labels,
            )
            return class_outputs, None, outputs.loss

        # Resolve the tokenizer first so a missing one fails before the
        # (potentially expensive) beam search instead of after it.
        tokenizer = self._resolve_tokenizer()
        generated_report_ids = self.report_generator.generate(
            inputs_embeds=encoder_inputs,
            max_length=512,
            num_beams=4,
            early_stopping=True,
        )
        generated_report = tokenizer.batch_decode(
            generated_report_ids,
            skip_special_tokens=True,
        )
        return class_outputs, generated_report, None