File size: 1,106 Bytes
24b4b92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c1f4f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
from pathlib import Path
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer ,GPT2Model

# Both artifacts live next to this file, so resolve them relative to it
# rather than to the current working directory.
current_path = os.path.dirname(os.path.abspath(__file__))
tokenizer_path, model_path = (
    os.path.join(current_path, artifact)
    for artifact in ("gpt_tokenizer", "gpt2_3epoch")
)

# Load the local tokenizer and the fine-tuned LM-head model from disk.
# NOTE(review): a hub checkpoint name (e.g. "gpt2-medium") could be swapped
# in for experimentation; this file only loads the local directories.
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
model = GPT2LMHeadModel.from_pretrained(model_path)

# Register the two marker tokens and a dedicated padding token, then grow the
# model's embedding matrix so every tokenizer id has a corresponding row.
extra_tokens = ["<email>", "<subject>"]
tokenizer.add_tokens(extra_tokens)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))

def subject_gen_func(email, *, max_length=1024):
    """Generate a continuation of ``email`` with the module-level GPT-2 model.

    Args:
        email: Prompt text, encoded verbatim with the module-level tokenizer.
        max_length: Upper bound on the total sequence length (prompt plus
            generated tokens) passed to ``model.generate``. Defaults to 1024,
            the value previously hard-coded here.

    Returns:
        The first generated sequence decoded to text with special tokens
        (such as the ``[PAD]`` token registered at import time) stripped.
        For GPT-2, ``generate`` returns the prompt tokens followed by the new
        ones, so the result begins with the prompt text.

    Raises:
        ValueError: If ``email`` encodes to zero tokens; generation cannot
            start from an empty prompt.
    """
    # Put the inputs on whatever device the model lives on instead of a
    # hard-coded "cpu", so moving the model never causes a device mismatch.
    device = model.device
    input_ids = tokenizer.encode(email, return_tensors="pt").to(device)
    if input_ids.shape[-1] == 0:
        raise ValueError("email must encode to at least one token")

    # A single, unpadded sequence: every position should be attended to.
    attention_mask = torch.ones_like(input_ids)

    # NOTE(review): a dedicated [PAD] token is registered at import time, but
    # EOS is still passed as pad_token_id to keep the original generation
    # settings; presumably no padding is emitted for a single sequence.
    # NOTE(review): standard GPT-2 configs cap positions at 1024
    # (model.config.n_positions), so longer prompts fail in the forward
    # pass -- TODO confirm for this checkpoint.
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)