# PatentMatch / infer.py
# DataRaptor's picture
# Upload 5 files
# 152844c
# raw
# history blame
# 4.24 kB
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
from torch.utils.data import Dataset
class MeanPooling(nn.Module):
    """Average token embeddings along dim 1, weighting each token by its mask."""

    def __init__(self):
        super().__init__()

    def forward(self, last_hidden_state, attention_mask):
        """Return the mask-weighted mean of ``last_hidden_state`` over dim 1.

        Args:
            last_hidden_state: Tensor of token embeddings; the reduction runs
                over its dim 1.
            attention_mask: Mask that receives a trailing axis and is expanded
                to ``last_hidden_state.size()`` before being cast to float.

        Returns:
            The masked sum over dim 1 divided by the (clamped) mask sum.
        """
        weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        weighted_total = (last_hidden_state * weights).sum(1)
        # Clamp the denominator so an all-zero mask cannot divide by zero.
        denominator = weights.sum(1).clamp(min=1e-9)
        return weighted_total / denominator
class MeanPoolingLayer(nn.Module):
    """Mean-pool the first element of the encoder output, then apply one linear layer."""

    def __init__(self, input_size, target_size):
        """Create the pooling step and an ``input_size -> target_size`` projection."""
        super().__init__()
        self.pool = MeanPooling()
        self.fc = nn.Linear(input_size, target_size)

    def forward(self, inputs, mask):
        """Pool ``inputs[0]`` under ``mask`` and return the projected result.

        Args:
            inputs: Indexable encoder output; only element 0 is read and it is
                treated as the last hidden states.
            mask: Attention mask forwarded unchanged to the pooling step.
        """
        pooled = self.pool(inputs[0], mask)
        return self.fc(pooled)
def weight_init_normal(module, model):
    """Re-initialise ``module`` in place according to its layer type.

    * ``nn.Linear``: weights drawn from N(0, ``initializer_range``); bias, if
      present, zeroed.
    * ``nn.Embedding``: weights drawn from N(0, ``initializer_range``); the
      ``padding_idx`` row, if any, zeroed.
    * ``nn.LayerNorm``: bias zeroed and weight set to 1.0, each only when the
      parameter exists.
    * Any other module type is left untouched.

    Args:
        module: The module to initialise.
        model: Object exposing ``config.initializer_range``; it is only read
            for ``nn.Linear`` and ``nn.Embedding`` modules.
    """
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) registers weight and bias as
        # None; guard both so such layers are skipped instead of raising.
        if module.bias is not None:
            module.bias.data.zero_()
        if module.weight is not None:
            module.weight.data.fill_(1.0)
class USPPPMModel(nn.Module):
    """Pretrained transformer backbone followed by a mean-pooling linear head.

    The head maps the pooled hidden states to a single output feature per
    input row. The tokenizer matching the backbone is kept on the instance as
    ``self.tokenizer``.
    """

    def __init__(self, backbone):
        """Load config, weights and tokenizer for ``backbone`` and build the head.

        Args:
            backbone: Model identifier or path accepted by the ``Auto*``
                ``from_pretrained`` loaders.
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(backbone, output_hidden_states=True)
        self.model = AutoModel.from_pretrained(backbone, config=self.config)
        # Size the head from the backbone's own config so backbones whose
        # hidden width is not 768 also work; fall back to the previous
        # hard-coded 768 if the config does not expose ``hidden_size``.
        hidden_size = getattr(self.config, "hidden_size", 768)
        self.head = MeanPoolingLayer(hidden_size, 1)
        self.tokenizer = AutoTokenizer.from_pretrained(backbone)
        # NOTE: section marker tokens are currently disabled.
        # sectoks = ['[CTG]', '[CTX]', '[ANC]', '[TGT]']
        # self.tokenizer.add_special_tokens({'additional_special_tokens': sectoks})
        # self.model.resize_token_embeddings(len(self.tokenizer))

    def _init_weights(self, layer):
        """Apply ``weight_init_normal`` to every submodule of ``layer``."""
        for module in layer.modules():
            weight_init_normal(module, self)

    def forward(self, inputs):
        """Run the backbone on ``inputs`` and return the head's output.

        Args:
            inputs: Mapping of keyword arguments for the backbone; it must
                contain ``'attention_mask'``, which is also used for pooling.
        """
        outputs = self.model(**inputs)
        return self.head(outputs, inputs['attention_mask'])
table = """
A: Human Necessities
B: Operations and Transport
C: Chemistry and Metallurgy
D: Textiles
E: Fixed Constructions
F: Mechanical Engineering
G: Physics
H: Electricity
Y: Emerging Cross-Sectional Technologies
"""
# Non-empty "<letter>: <title>" lines of the listing above.
splits = [line for line in table.splitlines() if line]
# Rebind ``table`` to a mapping from the section letter to its title.
table = dict(line.split(': ', 1) for line in splits)
class USPPPMDataset(Dataset):
    """Encode a single anchor/target/title record into model-ready tensors."""

    def __init__(self, tokenizer, max_length):
        """Keep the tokenizer and the fixed length every sample is padded/truncated to."""
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # Always reports zero; records are handed to __getitem__ directly.
        return 0

    def __getitem__(self, x):
        """Tokenize one record and return ``(inputs, label)``.

        Args:
            x: Mapping with ``'anchor'``, ``'target'``, ``'title'`` and
                ``'label'`` entries (a record, not an integer index).

        Returns:
            The tokenizer output with every field converted to a long tensor
            carrying a leading axis of size 1, and the label as a float tensor.
        """
        score = x['label']
        sep = '' + self.tokenizer.sep_token + ''
        # The three text fields are joined with the tokenizer's separator token.
        text = sep.join((x['anchor'], x['target'], x['title']))
        encoded = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=False,
        )
        # Overwrite each field in place so the tokenizer's container is returned.
        for key, values in list(encoded.items()):
            encoded[key] = torch.tensor(values, dtype=torch.long).unsqueeze(0)
        return encoded, torch.tensor(score, dtype=torch.float)
if __name__ == '__main__':
    # Smoke test: load the fine-tuned weights on CPU and score one example.
    model = USPPPMModel('microsoft/deberta-v3-small')
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # from a trusted source.
    state_dict = torch.load('model_weights.pth', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()

    ds = USPPPMDataset(model.tokenizer, 133)
    sample = {
        'anchor': 'sprayed',
        'target': 'thermal sprayed coating',
        'title': 'building',
        'label': 0,  # the returned label is discarded below
    }
    x, _ = ds[sample]
    with torch.no_grad():
        y = model(x)
    print('y:', y)