# model.py
import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM


class ImageToTextProjector(nn.Module):
    """Projects video embeddings into the text-embedding space of the report generator."""

    def __init__(self, image_embedding_dim, text_embedding_dim):
        super().__init__()
        self.fc = nn.Linear(image_embedding_dim, text_embedding_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.fc(x)
        x = self.activation(x)
        x = self.dropout(x)
        return x


class CombinedModel(nn.Module):
    """Joint classification and report-generation model.

    The video backbone produces a 512-dim embedding per clip; a linear head
    classifies it, and a projector maps it into the encoder embedding space
    of a seq2seq report generator.
    """

    def __init__(self, video_model, report_generator, num_classes, projector, tokenizer):
        super().__init__()
        self.video_model = video_model
        self.report_generator = report_generator
        self.classifier = nn.Linear(512, num_classes)  # assumes 512-dim video embeddings
        self.projector = projector
        self.dropout = nn.Dropout(p=0.5)
        self.tokenizer = tokenizer  # needed to decode generated report ids

    def forward(self, images, labels=None):
        video_embeddings = self.video_model(images)        # (batch, 512)
        video_embeddings = self.dropout(video_embeddings)
        class_outputs = self.classifier(video_embeddings)  # (batch, num_classes)

        # Project into the text model's embedding space and add a sequence
        # dimension so the embedding acts as a length-1 encoder input.
        projected_embeddings = self.projector(video_embeddings)
        encoder_inputs = projected_embeddings.unsqueeze(1)  # (batch, 1, text_dim)

        if labels is not None:
            # Training: teacher-forced forward pass returns the LM loss.
            outputs = self.report_generator(
                inputs_embeds=encoder_inputs,
                labels=labels,
            )
            gen_loss = outputs.loss
            generated_report = None
        else:
            # Inference: beam-search decode, then detokenize to text.
            generated_report_ids = self.report_generator.generate(
                inputs_embeds=encoder_inputs,
                max_length=512,
                num_beams=4,
                early_stopping=True,
            )
            generated_report = self.tokenizer.batch_decode(
                generated_report_ids, skip_special_tokens=True
            )
            gen_loss = None

        return class_outputs, generated_report, gen_loss
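
For context, here is a minimal sketch of how these pieces might be wired together. The r3d_18 backbone, the t5-small checkpoint, the input shape, and num_classes are illustrative assumptions, not part of the file above; any 512-dim video encoder and seq2seq model with a matching tokenizer would do. Note that generate(inputs_embeds=...) on encoder-decoder models requires a reasonably recent transformers release.

# usage_sketch.py -- hypothetical wiring, assuming a 512-dim video backbone
# (torchvision's r3d_18 with its head removed) and t5-small (d_model=512).
import torch
import torch.nn as nn
from torchvision.models.video import r3d_18
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from model import ImageToTextProjector, CombinedModel

backbone = r3d_18(weights=None)
backbone.fc = nn.Identity()  # expose the 512-dim clip embedding

tokenizer = AutoTokenizer.from_pretrained("t5-small")
report_generator = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

projector = ImageToTextProjector(
    image_embedding_dim=512,
    text_embedding_dim=report_generator.config.d_model,  # 512 for t5-small
)
model = CombinedModel(backbone, report_generator, num_classes=3,
                      projector=projector, tokenizer=tokenizer)

clips = torch.randn(2, 3, 16, 112, 112)  # (batch, channels, frames, H, W)
model.eval()
with torch.no_grad():
    class_logits, reports, _ = model(clips)  # no labels -> generation branch
print(class_logits.shape, reports)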