# Page header captured with the file (not part of the module):
# shreyasvaidya - "Upload folder using huggingface_hub" - commit 01bb3bb (verified)
import torch
import os
from torch.nn import init
import cv2
import numpy as np
import time
import requests
from IndicPhotoOCR.detection import east_config as cfg
from IndicPhotoOCR.detection import east_preprossing as preprossing
from IndicPhotoOCR.detection import east_locality_aware_nms as locality_aware_nms
# Registry of downloadable model files, keyed by model name. "paths" and
# "urls" are parallel lists: ModelManager.ensure_model pairs them with zip(),
# so the i-th URL is saved to the i-th local path.
model_info = {
    "east": dict(
        paths=[cfg.checkpoint, cfg.pretrained_basemodel_path],
        urls=[
            "https://github.com/anikde/STocr/releases/download/e0.1.0/epoch_990_checkpoint.pth.tar",
            "https://github.com/anikde/STocr/releases/download/e0.1.0/mobilenet_v2.pth.tar",
        ],
    ),
}
class ModelManager:
    """Download the model files listed in ``model_info`` when they are missing."""

    def __init__(self):
        """No state is kept; paths and URLs come from the module-level ``model_info``."""
        pass

    def download_model(self, url, path, timeout=60):
        """Stream ``url`` to the local file ``path``.

        The payload is first written to ``<path>.part`` and only renamed to
        ``path`` once the whole body has been received, so an interrupted
        download never leaves a truncated file that later looks like a valid
        model.

        Arguments:
            url {str} -- address of the file to fetch
            path {str} -- destination file path

        Keyword Arguments:
            timeout {float} -- timeout in seconds handed to ``requests.get``
                (default: {60})

        A non-200 response is reported with a message and nothing is written.
        Network and file errors propagate to the caller.
        """
        partial_path = f"{path}.part"
        # The context manager releases the connection on every exit path.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download from {url}")
                return
            try:
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # Filter out keep-alive chunks
                            f.write(chunk)
            except BaseException:
                # Do not leave a half-written temporary file behind.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
                raise
        os.replace(partial_path, path)
        print(f"Downloaded: {path}")

    def ensure_model(self, model_name):
        """Make sure every file registered for ``model_name`` exists locally.

        Arguments:
            model_name {str} -- key into ``model_info``

        Raises:
            KeyError -- if ``model_name`` is not registered in ``model_info``
        """
        entry = model_info[model_name]
        for model_path, url in zip(entry["paths"], entry["urls"]):
            if os.path.exists(model_path):
                print(f"Model already exists at {model_path}. No need to download.")
                continue
            # Create the directory of the file being fetched. A bare filename
            # has an empty dirname, which os.makedirs would reject.
            target_dir = os.path.dirname(model_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            print(f"Model not found locally. Downloading {model_name} from {url}...")
            self.download_model(url, model_path)
# Import-time side effect: make sure the EAST files registered under
# model_info["east"] exist locally, downloading any that are missing.
model_manager = ModelManager()
model_manager.ensure_model("east")
def init_weights(m_list, init_type=cfg.init_type, gain=0.02):
    """Initialise the weights of every module in ``m_list`` in place.

    Modules whose class name contains ``Conv`` or ``Linear`` (and that have a
    ``weight`` attribute) are initialised with the scheme named by
    ``init_type`` and get a zero bias. Modules whose class name contains
    ``BatchNorm2d`` get weights drawn from N(1.0, gain) and a zero bias.
    Other modules are left untouched.

    Arguments:
        m_list {iterable} -- modules to initialise

    Keyword Arguments:
        init_type {str} -- one of 'normal', 'xavier', 'kaiming', 'orthogonal'
            (default: {cfg.init_type})
        gain {float} -- std for 'normal' and BatchNorm2d, gain for 'xavier'
            and 'orthogonal'; not used by 'kaiming' (default: {0.02})

    Raises:
        NotImplementedError -- if ``init_type`` is not a supported scheme
    """
    # Report the scheme actually applied, which may differ from cfg.init_type.
    print("EAST <==> Prepare <==> Init Network'{}' <==> Begin".format(init_type))
    for m in m_list:
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')  # good for relu
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # With affine=False the layer has weight/bias set to None; skip
            # those instead of failing on ``None.data``.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
    print("EAST <==> Prepare <==> Init Network'{}' <==> Done".format(init_type))
def Loading_checkpoint(model, optimizer, scheduler, filename='checkpoint.pth.tar'):
    """Restore model, optimizer and scheduler state from ``cfg.checkpoint``.

    The checkpoint is expected to be a dict with the keys 'epoch',
    'state_dict', 'optimizer' and 'scheduler', as written by
    ``save_checkpoint``.

    Arguments:
        model -- object whose ``load_state_dict`` receives checkpoint['state_dict']
        optimizer -- object whose ``load_state_dict`` receives checkpoint['optimizer']
        scheduler -- object whose ``load_state_dict`` receives checkpoint['scheduler']

    Keyword Arguments:
        filename {str} -- unused; the path always comes from ``cfg.checkpoint``.
            Kept so existing callers keep working
            (default: {'checkpoint.pth.tar'})

    Returns:
        int -- the epoch to resume from (stored epoch + 1)
    """
    weightpath = os.path.abspath(cfg.checkpoint)
    print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath))
    # Deserialise onto the CPU so a checkpoint saved on a GPU also loads on
    # hosts without CUDA; load_state_dict copies values into the existing
    # parameters, which stay on their own device.
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(weightpath, map_location='cpu')
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath))
    return start_epoch
def save_checkpoint(epoch, model, optimizer, scheduler, filename='checkpoint.pth.tar'):
    """Write a training checkpoint for ``epoch`` into ``cfg.save_model_path``.

    The file is named ``epoch_<epoch>_checkpoint.pth.tar`` and holds a dict
    with the keys 'epoch', 'state_dict', 'optimizer' and 'scheduler'.

    Arguments:
        epoch -- epoch number stored in the checkpoint and used in the file name
        model -- object providing ``state_dict()``
        optimizer -- object providing ``state_dict()``
        scheduler -- object providing ``state_dict()``

    Keyword Arguments:
        filename {str} -- unused; the file name is always derived from
            ``epoch``. Kept so existing callers keep working
            (default: {'checkpoint.pth.tar'})
    """
    print('EAST <==> Save weight - epoch {} <==> Begin'.format(epoch))
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }
    weight_dir = cfg.save_model_path
    # makedirs also creates missing parent directories and does not fail when
    # the directory already exists, unlike an exists() check followed by mkdir().
    os.makedirs(weight_dir, exist_ok=True)
    file_path = os.path.join(weight_dir, 'epoch_' + str(epoch) + '_checkpoint.pth.tar')
    torch.save(state, file_path)
    print('EAST <==> Save weight - epoch {} <==> Done'.format(epoch))
class Regularization(torch.nn.Module):
    """Penalty term: ``weight_decay`` times the sum of p-norms of a model's weights.

    Only parameters whose name contains 'weight' contribute.
    """

    def __init__(self, model, weight_decay, p=2):
        """Store the settings and collect the initial weight list.

        Arguments:
            model -- module providing ``named_parameters()``
            weight_decay {float} -- scale of the penalty, must be >= 0

        Keyword Arguments:
            p {int} -- order of the norm passed to ``torch.norm`` (default: {2})

        Raises:
            ValueError -- if ``weight_decay`` is negative
        """
        super(Regularization, self).__init__()
        if weight_decay < 0:
            # Raise rather than exit(0): library code must not terminate the
            # process, and certainly not with a success status.
            raise ValueError("param weight_decay can not <0")
        self.model = model
        self.weight_decay = weight_decay
        self.p = p
        self.weight_list = self.get_weight(model)
        # self.weight_info(self.weight_list)

    def to(self, device):
        """Remember ``device``, move the module to it and return ``self``."""
        self.device = device
        super().to(device)
        return self

    def forward(self, model):
        """Return the penalty computed from the current weights of ``model``."""
        self.weight_list = self.get_weight(model)  # fetch the latest weights
        return self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)

    def get_weight(self, model):
        """Return ``(name, parameter)`` pairs for parameters named like weights."""
        return [(name, param) for name, param in model.named_parameters() if 'weight' in name]

    def regularization_loss(self, weight_list, weight_decay, p=2):
        """Return ``weight_decay * sum(norm_p(w))`` over ``weight_list``.

        With an empty list the result is the plain number ``weight_decay * 0``.
        """
        norm_total = 0
        for _name, w in weight_list:
            norm_total = norm_total + torch.norm(w, p=p)
        return weight_decay * norm_total

    def weight_info(self, weight_list):
        """Print the names of the parameters that are being regularised."""
        print("---------------regularization weight---------------")
        for name, _w in weight_list:
            print(name)
        print("---------------------------------------------------")
def resize_image(im, max_side_len=2400):
    '''
    resize image to a size multiple of 32 which is required by the network
    :param im: the input image, an array whose shape unpacks as (h, w, channels)
    :param max_side_len: limit of max image size to avoid out of memory in gpu;
        currently unused because the max-side limiting step is disabled
    :return: the resized image and the resize ratio as (ratio_h, ratio_w)
    '''
    h, w, _ = im.shape

    # A side that is already a multiple of 32 is kept; any other side is
    # rounded down to a multiple of 32 and reduced by a further 32.
    resize_h = h if h % 32 == 0 else (h // 32 - 1) * 32
    resize_w = w if w % 32 == 0 else (w // 32 - 1) * 32

    # For sides below 64 the expression above yields 0 or a negative value,
    # which cv2.resize rejects; never go below one 32-pixel cell.
    resize_h = max(32, resize_h)
    resize_w = max(32, resize_w)

    im = cv2.resize(im, (int(resize_w), int(resize_h)))

    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)

    return im, (ratio_h, ratio_w)
def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
    '''
    restore text boxes from score map and geo map
    :param score_map: per-pixel text score; a 4-d input is reduced to [0, :, :, 0]
    :param geo_map: per-pixel geometry; a 4-d score map also reduces this to [0, :, :, :]
    :param timer: dict that receives the 'restore' and 'nms' durations in seconds
    :param score_map_thresh: threshold for score map
    :param box_thresh: threshold for boxes
    :param nms_thres: threshold for nms
    :return: (boxes, timer); boxes is None when nms leaves no box, otherwise
        rows of 8 corner coordinates followed by a score
    '''
    # Drop the batch dimension (and the score channel) of 4-d inputs.
    if len(score_map.shape) == 4:
        score_map = score_map[0, :, :, 0]
        geo_map = geo_map[0, :, :, :]

    # (row, col) positions above the score threshold, ordered by row.
    candidates = np.argwhere(score_map > score_map_thresh)
    candidates = candidates[np.argsort(candidates[:, 0])]
    rows = candidates[:, 0]
    cols = candidates[:, 1]

    # Restore one quadrilateral per candidate pixel. Positions are passed as
    # (x, y) scaled by 4 - presumably the maps are at 1/4 of the input
    # resolution; confirm against the network definition.
    tick = time.time()
    quads = preprossing.restore_rectangle(candidates[:, ::-1] * 4,
                                          geo_map[rows, cols, :])  # N*4*2
    boxes = np.zeros((quads.shape[0], 9), dtype=np.float32)
    boxes[:, :8] = quads.reshape((-1, 8))
    boxes[:, 8] = score_map[rows, cols]
    timer['restore'] = time.time() - tick

    # Locality-aware non-maximum suppression.
    tick = time.time()
    boxes = locality_aware_nms.nms_locality(boxes.astype(np.float64), nms_thres)
    timer['nms'] = time.time() - tick

    if boxes.shape[0] == 0:
        return None, timer

    # Replace each score by the mean of the score map inside the box, then
    # drop low-scoring boxes (this differs from the original paper).
    for idx in range(boxes.shape[0]):
        region = np.zeros_like(score_map, dtype=np.uint8)
        polygon = boxes[idx, :8].reshape((-1, 4, 2)).astype(np.int32) // 4
        cv2.fillPoly(region, polygon, 1)
        boxes[idx, 8] = cv2.mean(score_map, region)[0]

    keep = boxes[:, 8] > box_thresh
    return boxes[keep], timer
def sort_poly(p):
    """Reorder the four vertices of ``p`` to start at the smallest x + y.

    The vertices are rotated so that the one with the smallest coordinate sum
    comes first. If the edge from the first to the second vertex spans more
    in x than in y the rotated order is returned as is; otherwise the
    traversal direction is flipped (vertices 1 and 3 are swapped).

    :param p: array of 4 points, one (x, y) pair per row
    :return: the reordered array
    """
    first = np.argmin(np.sum(p, axis=1))
    rotated = p[[(first + step) % 4 for step in range(4)]]

    span_x = abs(rotated[0, 0] - rotated[1, 0])
    span_y = abs(rotated[0, 1] - rotated[1, 1])
    if span_x > span_y:
        return rotated
    return rotated[[0, 3, 2, 1]]
def mean_image_subtraction(images, means=cfg.means):
    '''
    image normalization: subtract a per-channel mean, in place
    :param images: batch whose ``.data`` is indexed as [batch, channel, :, :]
        (channels are taken from dimension 1)
    :param means: one mean value per channel
    :return: the same ``images`` object, modified in place
    :raises ValueError: when len(means) differs from the channel count
    '''
    channel_count = images.data.shape[1]
    if channel_count != len(means):
        raise ValueError('len(means) must match the number of channels')

    for channel, channel_mean in enumerate(means):
        images.data[:, channel, :, :] -= channel_mean

    return images