###################################################################
# File Name: GCN.py
# Author: S.X.Zhang
###################################################################
import torch
from torch import nn, Tensor
import numpy as np
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
class Positional_encoding(nn.Module):
def __init__(self, PE_size, n_position=256):
super(Positional_encoding, self).__init__()
self.PE_size = PE_size
self.n_position = n_position
self.register_buffer('pos_table', self.get_encoding_table(n_position, PE_size))
    def get_encoding_table(self, n_position, PE_size):
        # Sinusoidal table of shape (1, n_position, PE_size):
        # sin on even feature indices, cos on odd ones.
        position_table = np.array(
            [[pos / np.power(10000, 2. * i / self.PE_size) for i in range(self.PE_size)]
             for pos in range(n_position)])
        position_table[:, 0::2] = np.sin(position_table[:, 0::2])
        position_table[:, 1::2] = np.cos(position_table[:, 1::2])
        return torch.FloatTensor(position_table).unsqueeze(0)

    def forward(self, inputs):
        # Add non-trainable positional encodings for the first inputs.size(1) positions.
        return inputs + self.pos_table[:, :inputs.size(1), :].clone().detach()
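# Example usage (sketch; the shapes below are illustrative assumptions, not
# values taken from the TextBPN pipeline):
#   pe = Positional_encoding(PE_size=36)
#   feats = torch.zeros(2, 100, 36)   # (batch, seq_len, PE_size)
#   feats = pe(feats)                 # adds encodings for positions 0..99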
class MultiHeadAttention(nn.Module):
def __init__(self, num_heads, embed_dim, dropout=0.1, if_resi=True):
super(MultiHeadAttention, self).__init__()
self.layer_norm = nn.LayerNorm(embed_dim)
self.MultiheadAttention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
self.Q_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
self.K_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
self.V_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
self.if_resi = if_resi
    def forward(self, inputs):
        # Pre-norm self-attention. Note that nn.MultiheadAttention defaults to the
        # (seq_len, batch, embed_dim) layout, so dim 0 of `inputs` is treated as
        # the sequence dimension here.
        query = self.layer_norm(inputs)
        q = self.Q_proj(query)
        k = self.K_proj(query)
        v = self.V_proj(query)
        attn_output, _ = self.MultiheadAttention(q, k, v)
        if self.if_resi:
            attn_output = attn_output + inputs
        return attn_output
class FeedForward(nn.Module):
def __init__(self, in_channel, FFN_channel, if_resi=True):
        super(FeedForward, self).__init__()
        # hidden width, e.g. 1024 / 2048
        output_channel = (FFN_channel, in_channel)
self.fc1 = nn.Sequential(nn.Linear(in_channel, output_channel[0]), nn.ReLU())
self.fc2 = nn.Linear(output_channel[0], output_channel[1])
self.layer_norm = nn.LayerNorm(in_channel)
self.if_resi = if_resi
    def forward(self, inputs):
        # Pre-norm position-wise feed-forward block.
        outputs = self.layer_norm(inputs)
        outputs = self.fc1(outputs)
        outputs = self.fc2(outputs)
        if self.if_resi:
            outputs = outputs + inputs
        return outputs
class TransformerLayer(nn.Module):
def __init__(self, out_dim, in_dim, num_heads, attention_size,
dim_feedforward=1024, drop_rate=0.1, if_resi=True, block_nums=3):
super(TransformerLayer, self).__init__()
self.block_nums = block_nums
self.if_resi = if_resi
self.linear = nn.Linear(in_dim, attention_size)
        for i in range(self.block_nums):
            # nn.Module.__setattr__ registers each block as a named submodule
            # ('MHA_self_0', 'FFN_0', ...), so its parameters are tracked normally.
            self.__setattr__('MHA_self_%d' % i,
                             MultiHeadAttention(num_heads, attention_size,
                                                dropout=drop_rate, if_resi=if_resi))
            self.__setattr__('FFN_%d' % i,
                             FeedForward(out_dim, dim_feedforward, if_resi=if_resi))
    def forward(self, query):
        inputs = self.linear(query)
        for i in range(self.block_nums):
            # self-attention followed by the feed-forward block
            outputs = self.__getattr__('MHA_self_%d' % i)(inputs)
            outputs = self.__getattr__('FFN_%d' % i)(outputs)
            if self.if_resi:
                inputs = inputs + outputs
            else:
                inputs = outputs
        return inputs
class Transformer(nn.Module):
def __init__(self, in_dim, out_dim, num_heads=8,
dim_feedforward=1024, drop_rate=0.1, if_resi=False, block_nums=3):
super().__init__()
self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
self.conv1 = nn.Conv1d(in_dim, out_dim, 1, dilation=1)
# self.pos_embedding = Positional_encoding(in_dim)
self.transformer = TransformerLayer(out_dim, in_dim, num_heads, attention_size=out_dim,
dim_feedforward=dim_feedforward, drop_rate=drop_rate,
if_resi=if_resi, block_nums=block_nums)
self.prediction = nn.Sequential(
nn.Conv1d(2*out_dim, 128, 1),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Conv1d(128, 64, 1),
nn.ReLU(inplace=True),
# nn.Dropout(0.1),
nn.Conv1d(64, 2, 1))
    def forward(self, x, adj):
        # x: (batch, in_dim, num_points); `adj` is accepted but not used by this
        # transformer-based variant.
        x = self.bn0(x)
        x1 = x.permute(0, 2, 1)
        # x1 = self.pos_embedding(x1)
        x1 = self.transformer(x1)
        x1 = x1.permute(0, 2, 1)
        # concatenate transformer features with a 1x1-conv projection of the input
        x = torch.cat([x1, self.conv1(x)], dim=1)
        # x = x1+self.conv1(x)
        pred = self.prediction(x)
        return pred
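

if __name__ == "__main__":
    # Minimal smoke test (sketch). The sizes below are illustrative assumptions
    # and not taken from the TextBPN configuration: 4 text instances, 36-dim
    # node features, 20 boundary points each. `adj` is unused by forward(),
    # so None is passed.
    model = Transformer(in_dim=36, out_dim=32, num_heads=8)
    model.eval()  # use running BatchNorm stats and disable dropout
    x = torch.randn(4, 36, 20)  # (batch, in_dim, num_points)
    with torch.no_grad():
        pred = model(x, adj=None)
    print(pred.shape)  # expected: torch.Size([4, 2, 20])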