import torch
import os
from torch.nn import init
import cv2
import numpy as np
import time
import requests

from IndicPhotoOCR.detection import east_config as cfg
from IndicPhotoOCR.detection import east_preprossing as preprossing
from IndicPhotoOCR.detection import east_locality_aware_nms as locality_aware_nms

# Registry of downloadable weights: local target paths paired with release URLs.
model_info = {
    "east": {
        "paths": [cfg.checkpoint, cfg.pretrained_basemodel_path],
        "urls": [
            "https://github.com/anikde/STocr/releases/download/e0.1.0/epoch_990_checkpoint.pth.tar",
            "https://github.com/anikde/STocr/releases/download/e0.1.0/mobilenet_v2.pth.tar",
        ],
    },
}

class ModelManager:
    """Downloads and caches the model weights listed in ``model_info``."""

    def download_model(self, url, path):
        """Stream the file at ``url`` to ``path`` in 8 KB chunks."""
        response = requests.get(url, stream=True)
        if response.status_code == 200:
            with open(path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
            print(f"Downloaded: {path}")
        else:
            print(f"Failed to download from {url} (status {response.status_code})")

    def ensure_model(self, model_name):
        """Download every file registered for ``model_name`` that is missing locally."""
        model_paths = model_info[model_name]["paths"]
        urls = model_info[model_name]["urls"]

        for model_path, url in zip(model_paths, urls):
            # Create the directory of each target path (the original code only
            # ever created the base-model directory, regardless of which file
            # was being downloaded).
            target_dir = os.path.dirname(model_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            if not os.path.exists(model_path):
                print(f"Model not found locally. Downloading {model_name} from {url}...")
                self.download_model(url, model_path)
            else:
                print(f"Model already exists at {model_path}. No need to download.")

# Fetch the EAST weights at import time so the detector can be built without a
# separate setup step.
model_manager = ModelManager()
model_manager.ensure_model("east")

def init_weights(m_list, init_type=cfg.init_type, gain=0.02):
    print("EAST <==> Prepare <==> Init Network '{}' <==> Begin".format(init_type))

    for m in m_list:
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm layers: weight ~ N(1, gain), bias = 0.
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print("EAST <==> Prepare <==> Init Network '{}' <==> Done".format(init_type))

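# Illustrative usage (a minimal sketch; ``build_east_model`` is a hypothetical
# constructor standing in for the detector network built elsewhere in this
# package, it is not defined in this module):
#     detector = build_east_model()  # hypothetical constructor
#     init_weights(detector.modules(), init_type=cfg.init_type, gain=0.02)
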
def Loading_checkpoint(model, optimizer, scheduler, filename='checkpoint.pth.tar'):
    """Restore model, optimizer, and scheduler state from ``cfg.checkpoint``.

    Arguments:
        model {torch.nn.Module} -- network to load weights into
        optimizer -- optimizer whose state is restored
        scheduler -- LR scheduler whose state is restored

    Keyword Arguments:
        filename {str} -- kept for API compatibility; the path actually loaded
            is ``cfg.checkpoint`` (default: {'checkpoint.pth.tar'})

    Returns:
        int -- the epoch to resume training from
    """
    weightpath = os.path.abspath(cfg.checkpoint)
    print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath))
    checkpoint = torch.load(weightpath)
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath))

    return start_epoch

def save_checkpoint(epoch, model, optimizer, scheduler, filename='checkpoint.pth.tar'):
    """Serialize the training state for ``epoch`` to ``cfg.save_model_path``.

    Arguments:
        epoch {int} -- the epoch just completed
        model {torch.nn.Module} -- network whose weights are saved
        optimizer -- optimizer whose state is saved
        scheduler -- LR scheduler whose state is saved

    Keyword Arguments:
        filename {str} -- kept for API compatibility; the file is actually
            named ``epoch_<epoch>_checkpoint.pth.tar`` (default: {'checkpoint.pth.tar'})
    """
    print('EAST <==> Save weight - epoch {} <==> Begin'.format(epoch))
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    weight_dir = cfg.save_model_path
    # makedirs (rather than mkdir) so missing parent directories do not raise.
    os.makedirs(weight_dir, exist_ok=True)
    filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar'
    file_path = os.path.join(weight_dir, filename)
    torch.save(state, file_path)
    print('EAST <==> Save weight - epoch {} <==> Done'.format(epoch))

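# Checkpoint round trip (a minimal sketch; ``model``, ``optimizer``,
# ``scheduler``, and ``num_epochs`` are stand-ins for training objects created
# elsewhere, not defined in this module):
#     save_checkpoint(epoch, model, optimizer, scheduler)
#     ...
#     start_epoch = Loading_checkpoint(model, optimizer, scheduler)
#     for epoch in range(start_epoch, num_epochs):
#         ...
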
class Regularization(torch.nn.Module):
    """Computes an L-p penalty over all weight parameters of ``model``."""

    def __init__(self, model, weight_decay, p=2):
        super(Regularization, self).__init__()
        if weight_decay < 0:
            # Raise instead of print-and-exit(0), which would silently report success.
            raise ValueError("weight_decay must be non-negative")
        self.model = model
        self.weight_decay = weight_decay
        self.p = p
        self.weight_list = self.get_weight(model)

    def to(self, device):
        self.device = device
        super().to(device)
        return self

    def forward(self, model):
        # Re-collect the weights so the penalty tracks the current parameters.
        self.weight_list = self.get_weight(model)
        reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
        return reg_loss

    def get_weight(self, model):
        """Return (name, parameter) pairs for every weight tensor in ``model``."""
        weight_list = []
        for name, param in model.named_parameters():
            if 'weight' in name:
                weight_list.append((name, param))
        return weight_list

    def regularization_loss(self, weight_list, weight_decay, p=2):
        """Sum of the L-p norms of the weights, scaled by ``weight_decay``."""
        reg_loss = 0
        for name, w in weight_list:
            reg_loss = reg_loss + torch.norm(w, p=p)
        reg_loss = weight_decay * reg_loss
        return reg_loss

    def weight_info(self, weight_list):
        """Print the names of the parameters included in the penalty."""
        print("---------------regularization weight---------------")
        for name, w in weight_list:
            print(name)
        print("---------------------------------------------------")

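# Illustrative usage (a minimal sketch; ``model``, ``criterion``, ``pred``, and
# ``target`` are stand-ins for objects created in the training script):
#     reg = Regularization(model, weight_decay=1e-4, p=2)
#     loss = criterion(pred, target) + reg(model)
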
def resize_image(im, max_side_len=2400):
    '''
    resize image to a size multiple of 32 which is required by the network
    :param im: the input image (h x w x channel)
    :param max_side_len: limit on the max image side to avoid GPU out-of-memory;
        unused while the down-scaling block below stays disabled
    :return: the resized image and the (ratio_h, ratio_w) resize ratios
    '''
    h, w, _ = im.shape

    resize_w = w
    resize_h = h

    # Down-scaling by max_side_len is currently disabled:
    # if max(resize_h, resize_w) > max_side_len:
    #     ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w
    # else:
    #     ratio = 1.
    # resize_h = int(resize_h * ratio)
    # resize_w = int(resize_w * ratio)

    # Snap each side down to a multiple of 32, guarding against sides shorter
    # than 64 px collapsing to zero (cv2.resize fails on a zero-sized target).
    resize_h = resize_h if resize_h % 32 == 0 else max((resize_h // 32 - 1) * 32, 32)
    resize_w = resize_w if resize_w % 32 == 0 else max((resize_w // 32 - 1) * 32, 32)

    im = cv2.resize(im, (int(resize_w), int(resize_h)))

    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)

    return im, (ratio_h, ratio_w)

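# Illustrative usage (a minimal sketch; "page.jpg" is a placeholder path).
# Boxes predicted on the resized image map back to the original by dividing
# x coordinates by ratio_w and y coordinates by ratio_h:
#     im_resized, (ratio_h, ratio_w) = resize_image(cv2.imread("page.jpg"))
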
def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
    '''
    restore text boxes from score map and geo map
    :param score_map: per-pixel text confidence produced by the network
    :param geo_map: per-pixel box geometry produced by the network
    :param timer: dict collecting 'restore' and 'nms' timings
    :param score_map_thresh: threshold for the score map
    :param box_thresh: threshold for boxes
    :param nms_thres: threshold for nms
    :return: (N x 9 array of boxes [8 quad coords + score] or None, timer)
    '''
    if len(score_map.shape) == 4:
        score_map = score_map[0, :, :, 0]
        geo_map = geo_map[0, :, :, :]

    # Pixels confident enough to belong to text, sorted by row.
    xy_text = np.argwhere(score_map > score_map_thresh)
    xy_text = xy_text[np.argsort(xy_text[:, 0])]

    # Restore a rotated rectangle per text pixel (x4: maps are at 1/4 scale).
    start = time.time()
    text_box_restored = preprossing.restore_rectangle(xy_text[:, ::-1] * 4,
                                                      geo_map[xy_text[:, 0], xy_text[:, 1], :])
    boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
    boxes[:, :8] = text_box_restored.reshape((-1, 8))
    boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
    timer['restore'] = time.time() - start

    # Merge overlapping candidates with locality-aware NMS.
    start = time.time()
    boxes = locality_aware_nms.nms_locality(boxes.astype(np.float64), nms_thres)
    timer['nms'] = time.time() - start

    if boxes.shape[0] == 0:
        return None, timer

    # Re-score each surviving box by the mean score inside its polygon and
    # drop low-scoring boxes.
    for i, box in enumerate(boxes):
        mask = np.zeros_like(score_map, dtype=np.uint8)
        cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
        boxes[i, 8] = cv2.mean(score_map, mask)[0]
    boxes = boxes[boxes[:, 8] > box_thresh]
    return boxes, timer

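# Illustrative usage (a minimal sketch; ``score`` and ``geo`` stand in for the
# network outputs converted to NumPy arrays in N x H x W x C layout, and
# ``ratio_w``/``ratio_h`` come from resize_image above):
#     timer = {'restore': 0, 'nms': 0}
#     boxes, timer = detect(score, geo, timer)
#     if boxes is not None:
#         quads = boxes[:, :8].reshape((-1, 4, 2))
#         quads[:, :, 0] /= ratio_w   # back to original-image coordinates
#         quads[:, :, 1] /= ratio_h
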
def sort_poly(p):
    """Reorder the four corners of a quad so the point with the smallest x + y
    comes first and the first edge runs roughly horizontally (clockwise order)."""
    min_axis = np.argmin(np.sum(p, axis=1))
    p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
    if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
        return p
    else:
        return p[[0, 3, 2, 1]]

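# Worked example (an axis-aligned quad given in arbitrary corner order):
#     sort_poly(np.array([[10, 0], [0, 0], [0, 5], [10, 5]]))
#     # -> [[0, 0], [10, 0], [10, 5], [0, 5]]  (clockwise from top-left)
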
def mean_image_subtraction(images, means=cfg.means):
    '''
    image normalization
    :param images: batch tensor in bs * channel * h * w layout
    :param means: per-channel means to subtract
    :return: the normalized batch (modified in place)
    '''
    num_channels = images.data.shape[1]
    if len(means) != num_channels:
        raise ValueError('len(means) must match the number of channels')
    for i in range(num_channels):
        images.data[:, i, :, :] -= means[i]

    return images

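# Illustrative usage (a minimal sketch; ``batch`` stands in for an
# N x C x H x W float tensor produced by the data loader):
#     batch = mean_image_subtraction(batch, means=cfg.means)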