Isaoudata commited on
Commit
ff172a8
·
1 Parent(s): bec0361

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -33
main.py DELETED
@@ -1,33 +0,0 @@
1
- from transformers import GPT2Tokenizer
2
- import torch
3
- import pickle
4
-
5
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
6
- tokenizer.pad_token = tokenizer.eos_token
7
- #model = GPT2LMHeadModel.from_pretrained('gpt2')
8
- #model.load_state_dict(torch.load("/kaggle/input/poem-pt/poem_model.pt"))
9
- #torch.save(model,"poem_model.pt")
10
- model = torch.load("poem_model.pt")
11
-
12
- def infer(inp):
13
- inp = tokenizer(inp,return_tensors="pt")
14
- X = inp["input_ids"] #.to(device)
15
- a = inp["attention_mask"] #.to(device)
16
- output = model.generate(X,
17
- attention_mask=a,
18
- max_length=100,
19
- min_length=10,
20
- early_stopping=True,
21
- num_beams=5,
22
- no_repeat_ngram_size=2)
23
-
24
- output = tokenizer.decode(output[0])
25
-
26
- return output
27
-
28
- #pickle.dump(model,open('poem_model.pt','wb'))
29
-
30
- output = infer(" I shall go")
31
-
32
-
33
- print(output)