""" | |
Copyright (c) 2019-present NAVER Corp. | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
""" | |
import copy
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules_trba.transformation import TPS_SpatialTransformerNetwork
from modules_trba.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
from modules_trba.sequence_modeling import BidirectionalLSTM
from modules_trba.prediction import Attention

# Optional dependency: only required when opt.confidence_mode == 1 (edit-distance scoring).
try:
    from torch_edit_distance import levenshtein_distance
except ImportError:
    levenshtein_distance = None


class STRScore(nn.Module):
    def __init__(self, opt, converter, device, gtStr="", enableSingleCharAttrAve=False):
        super(STRScore, self).__init__()
        self.opt = opt
        self.converter = converter
        self.device = device
        self.gtStr = gtStr
        self.enableSingleCharAttrAve = enableSingleCharAttrAve
        self.blank = torch.tensor([-1], dtype=torch.float).to(self.device)
        self.separator = torch.tensor([-2], dtype=torch.float).to(self.device)

    # singleChar - if >= 0, the output of STRScore will be the score of a single
    # character instead of the whole string. The character index will be equal
    # to the parameter "singleChar".
    def setSingleCharOutput(self, singleChar):
        self.singleChar = singleChar

    def forward(self, preds):
        bs = preds.shape[0]
        # text_for_loss, length_for_loss = self.converter.encode(labels, batch_max_length=self.opt.batch_max_length)
        text_for_loss_length = self.opt.batch_max_length + 1
        length_for_pred = torch.IntTensor([self.opt.batch_max_length] * bs).to(self.device)
        if 'CTC' in self.opt.Prediction:
            # Greedy decoding for the CTC head.
            preds_size = torch.FloatTensor([preds.size(1)] * bs)
            if self.opt.baiduCTC:
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
            else:
                _, preds_index = preds.max(2)
            preds_str = self.converter.decode(preds_index.data, preds_size.data)
            preds = preds.log_softmax(2).permute(1, 0, 2)
        else:
            preds = preds[:, :text_for_loss_length, :]
            # Select max probability (greedy decoding), then decode index to character.
            _, preds_index = preds.max(2)
            preds_str = self.converter.decode(preds_index, length_for_pred)

        # Confidence score (argmax-based).
        sum = torch.FloatTensor([0] * bs).to(self.device)
        if self.enableSingleCharAttrAve:
            sum = torch.zeros((bs, preds.shape[2])).to(self.device)
        if self.opt.confidence_mode == 0:
            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            count = 0
            for one_hot_preds, pred, pred_max_prob in zip(preds_prob, preds_str, preds_max_prob):
                if 'Attn' in self.opt.Prediction:
                    if self.enableSingleCharAttrAve:
                        one_hot = one_hot_preds[self.singleChar, :]
                        sum[count] = one_hot
                    else:
                        pred_EOS = pred.find('[s]')
                        pred = pred[:pred_EOS]
                        pred_max_prob = pred_max_prob[:pred_EOS]  ### Use score of all letters
                        # pred_max_prob = pred_max_prob[0:1]  ### Use score of first letter only
                        if pred_max_prob.shape[0] == 0:
                            continue
                        if self.opt.scorer == "cumprod":
                            confidence_score = pred_max_prob.cumprod(dim=0)[-1]  ### Maximum is 1
                        elif self.opt.scorer == "mean":
                            confidence_score = torch.mean(pred_max_prob)  ### Maximum is 1
                        sum[count] += confidence_score
                elif 'CTC' in self.opt.Prediction:
                    if self.opt.scorer == "cumprod":
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]  ### Maximum is 1
                    elif self.opt.scorer == "mean":
                        confidence_score = torch.mean(pred_max_prob)  ### Maximum is 1
                    sum[count] += confidence_score
                count += 1
            if not self.enableSingleCharAttrAve:
                sum = sum.unsqueeze(1)  # (bs,) -> (bs, 1)
        elif self.opt.confidence_mode == 1:
            assert levenshtein_distance is not None, "torch_edit_distance is required for confidence_mode == 1"
            preds_prob = F.softmax(preds, dim=2)
            ### Predicted indices
            preds_max_prob = torch.argmax(preds_prob, 2)
            ### Ground-truth indices (one copy of the GT string per batch item)
            gtIndices, _ = self.converter.encode([self.gtStr for i in range(0, preds_prob.shape[0])], batch_max_length=self.opt.batch_max_length - 1)
            ### Acquire Levenshtein distance between predictions and ground truth
            m = torch.tensor([preds_prob.shape[1] for i in range(0, gtIndices.shape[0])], dtype=torch.float).to(self.device)
            n = torch.tensor([preds_prob.shape[1] for i in range(0, gtIndices.shape[0])], dtype=torch.float).to(self.device)
            preds_max_prob = preds_max_prob.type(torch.float)
            gtIndices = gtIndices.type(torch.float)
            r = levenshtein_distance(preds_max_prob.to(self.device), gtIndices.to(self.device), n, m,
                                     torch.cat([self.blank, self.separator]),
                                     torch.empty([], dtype=torch.float).to(self.device))
            rSoft = F.softmax(r[:, 2].type(torch.float), dim=0)
            rNorm = rSoft.max() - rSoft  # lower edit distance -> higher score
            sum = rNorm.unsqueeze(1)
        return sum
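

# Illustrative sketch (not part of the original file): STRScore is built to sit
# on top of the recognizer so that attribution tools see one scalar score per
# image. The wrapper below assumes `model`, `converter`, and `opt` come from
# the usual TRBA train/test setup; the function name is ours.
def _strscore_usage_sketch(model, converter, opt, device, image_batch):
    scorer = STRScore(opt=opt, converter=converter, device=device)
    scored_model = nn.Sequential(model, scorer).to(device)
    scores = scored_model(image_batch)  # shape (batch_size, 1) when confidence_mode == 0
    return scores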


class SuperPixler(nn.Module):
    def __init__(self, n_super_pixel, imageList, super_pixel_width, super_pixel_height, opt):
        super(SuperPixler, self).__init__()
        self.opt = opt
        self.imageList = imageList
        self.n_super_pixel = n_super_pixel
        self.super_pixel_width = super_pixel_width
        self.super_pixel_height = super_pixel_height

    def sampleImages(self, size):
        # Draw `size` images (with replacement) from the stored image list.
        newImgList = []
        for i in range(0, size):
            randIdx = random.randint(0, len(self.imageList) - 1)
            newImgList.append(copy.deepcopy(self.imageList[randIdx]))
        return np.array(newImgList)

    def forward(self, x):
        """
        In the forward step we accept the super-pixel masks and transform them
        into a batch of images where the masked super pixels are replaced by
        the mean color.
        """
        image = self.sampleImages(x.shape[0])
        self.image = image
        self.image_height = image.shape[2]
        self.image_width = image.shape[3]
        self.mean_color = self.image.mean()
        pixeled_image = self.image.copy()
        for i, super_pixel in enumerate(x.T):
            images_to_pixelate = [bool(p) for p in super_pixel]
            # Top-left corner of the i-th super pixel (column-major traversal).
            px = (i * self.super_pixel_height // self.image_height) * self.super_pixel_width
            py = i * self.super_pixel_height % self.image_height
            ### Paint the selected super pixel with the mean color in every masked image
            pixeled_image[images_to_pixelate, :, py:py + self.super_pixel_height, px:px + self.super_pixel_width] = self.mean_color
        return pixeled_image
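

# Illustrative sketch (not part of the original file): SuperPixler consumes one
# binary mask per perturbed sample, with one entry per super pixel, i.e.
# (imgH // super_pixel_height) * (imgW // super_pixel_width) entries. The
# helper below only demonstrates the expected mask shape; names are ours.
def _superpixler_mask_sketch(imgH, imgW, super_pixel_height, super_pixel_width, n_samples):
    n_super_pixel = (imgH // super_pixel_height) * (imgW // super_pixel_width)
    # One row per perturbed sample; 1 means "replace this super pixel with the mean color".
    masks = np.random.randint(0, 2, size=(n_samples, n_super_pixel))
    return masks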


class CastNumpy(nn.Module):
    def __init__(self, device):
        super(CastNumpy, self).__init__()
        self.device = device

    def forward(self, image):
        """
        In the forward function we accept the numpy inputs and cast them to a
        PyTorch tensor on the target device.
        """
        image = np.ascontiguousarray(image)
        image = torch.from_numpy(image).to(self.device)
        if image.ndimension() == 3:
            image = image.unsqueeze(0)  # add the missing batch dimension
        image_half = image.half()
        return image_half.float()
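

# Illustrative sketch (not part of the original file): CastNumpy lets the numpy
# batches produced by SuperPixler flow into the torch pipeline. The half/float
# round trip mirrors the cast in the class above; the input is assumed to be a
# (batch, channel, height, width) float array.
def _castnumpy_usage_sketch(device):
    caster = CastNumpy(device)
    dummy = np.zeros((2, 1, 32, 100), dtype=np.float32)  # e.g. two grayscale 32x100 crops
    tensor = caster(dummy)  # -> torch.FloatTensor of shape (2, 1, 32, 100) on `device`
    return tensor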


class Model(nn.Module):
    def __init__(self, opt, device, feature_ext_outputs=None):
        super(Model, self).__init__()
        self.opt = opt
        self.device = device
        self.gtText = None
        self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                       'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

        """ Transformation """
        if opt.Transformation == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
        else:
            print('No Transformation module specified')

        """ FeatureExtraction """
        if opt.FeatureExtraction == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

        """ Sequence modeling """
        if opt.SequenceModeling == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
            self.SequenceModeling_output = opt.hidden_size
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output

        """ Prediction """
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
        else:
            raise Exception('Prediction is neither CTC nor Attn')

        ### Register feature-map output hooks on the feature extractor
        if opt.output_feat_maps:
            feature_ext_outputs.set_feature_ext(self.FeatureExtraction)
            first_layer = list(self.FeatureExtraction.ConvNet._modules.items())[0][1]
            first_layer.register_backward_hook(feature_ext_outputs.append_first_grads)
            layerCount = 0
            for layer in self.FeatureExtraction.modules():
                if isinstance(layer, nn.Conv2d):
                    layerCount += 1
                    if layerCount >= opt.min_layer_out and layerCount <= opt.max_layer_out:
                        layer.register_forward_hook(feature_ext_outputs.append_layer_out)
                        layer.register_backward_hook(feature_ext_outputs.append_grad_out)

    def setGTText(self, text):
        self.gtText = text

    def forward(self, input, text="", is_train=True):
        if self.opt.is_shap:
            text = torch.LongTensor(input.shape[0], self.opt.batch_max_length + 1).fill_(0).to(self.device)
        elif self.gtText is not None:
            text = self.gtText
        else:
            text = torch.LongTensor(input.shape[0], self.opt.batch_max_length + 1).fill_(0).to(self.device)

        """ Transformation stage """
        tpsOut = input.contiguous()
        if not self.stages['Trans'] == "None":
            tpsOut = self.Transformation(tpsOut)

        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(tpsOut)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = visual_feature.squeeze(3)

        """ Sequence modeling stage """
        if self.stages['Seq'] == 'BiLSTM':
            contextual_feature = self.SequenceModeling(visual_feature)
        else:
            contextual_feature = visual_feature  # for convenience; this is NOT contextually modeled by BiLSTM

        """ Prediction stage """
        if self.stages['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())
        else:
            prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length)

        return prediction
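

# Illustrative sketch (not part of the original file): the classes above compose
# into a perturbation pipeline for attribution methods that feed binary
# super-pixel masks (SHAP-style explainers; cf. opt.is_shap above). All names
# below are ours, and the `opt` fields are assumed to match the repo's usual
# argument parser.
def _pipeline_sketch(opt, converter, device, image_list):
    super_pixel_h, super_pixel_w = 4, 4
    n_super_pixel = (opt.imgH // super_pixel_h) * (opt.imgW // super_pixel_w)
    pipeline = nn.Sequential(
        SuperPixler(n_super_pixel, image_list, super_pixel_w, super_pixel_h, opt),  # masks -> perturbed numpy images
        CastNumpy(device),                                                          # numpy -> torch tensor
        Model(opt, device),                                                         # recognizer predictions
        STRScore(opt=opt, converter=converter, device=device),                      # predictions -> scalar score
    )
    masks = torch.randint(0, 2, (8, n_super_pixel))  # 8 random perturbation samples
    return pipeline(masks)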