Isaoudata commited on
Commit
7fdb064
·
1 Parent(s): 40b3bdd

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +33 -0
  2. poem_model.pt +3 -0
main.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Tokenizer
2
+ import torch
3
+ import pickle
4
+
5
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
6
+ tokenizer.pad_token = tokenizer.eos_token
7
+ #model = GPT2LMHeadModel.from_pretrained('gpt2')
8
+ #model.load_state_dict(torch.load("/kaggle/input/poem-pt/poem_model.pt"))
9
+ #torch.save(model,"poem_model.pt")
10
+ model = torch.load("poem_model.pt")
11
+
12
+ def infer(inp):
13
+ inp = tokenizer(inp,return_tensors="pt")
14
+ X = inp["input_ids"] #.to(device)
15
+ a = inp["attention_mask"] #.to(device)
16
+ output = model.generate(X,
17
+ attention_mask=a,
18
+ max_length=100,
19
+ min_length=10,
20
+ early_stopping=True,
21
+ num_beams=5,
22
+ no_repeat_ngram_size=2)
23
+
24
+ output = tokenizer.decode(output[0])
25
+
26
+ return output
27
+
28
+ #pickle.dump(model,open('poem_model.pt','wb'))
29
+
30
+ output = infer(" I shall go")
31
+
32
+
33
+ print(output)
poem_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0d6b10f3fe4fba2c6a28894281e49ed41991ca371dcb26ae84b4485119561e1
3
+ size 510424627