Unggi committed on
Commit
3f7c440
·
1 Parent(s): 72d0b90
Files changed (2) hide show
  1. app.py +53 -0
  2. kobart-model-essay.pth +3 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import subprocess
import sys

import gradio as gr


# Runtime dependency bootstrap.
# NOTE(review): these packages belong in requirements.txt; installing at
# import time is a stopgap. `pip.main` is not a supported API (removed from
# pip's public interface in pip 10), so pip is invoked through the current
# interpreter as the pip documentation recommends.
def _ensure_installed(package):
    """Install *package* with pip, but only if it cannot be imported."""
    try:
        __import__(package)
    except ImportError:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", package],
            check=True,  # fail loudly if the install itself fails
        )


_ensure_installed("torch")
_ensure_installed("transformers")

import torch
import transformers
10
+
11
+
12
+ # saved_model
13
def load_model(model_path):
    """Load the fine-tuned KoBART essay model and its tokenizer.

    Parameters
    ----------
    model_path : str
        Path to a checkpoint saved as a dict with at least a ``"model"``
        entry holding the fine-tuned state dict. (The checkpoint also
        carries a ``"config"`` entry, unused here.)

    Returns
    -------
    tuple
        ``(model, tokenizer)`` — a ``BartForConditionalGeneration`` with the
        fine-tuned weights applied, and the matching KoBART fast tokenizer.
    """
    # weights_only=False keeps the pre-torch-2.6 behavior: this checkpoint
    # stores a full config object next to the raw state dict, which the
    # weights-only loader (the default since torch 2.6) would reject.
    # NOTE(review): full unpickling can execute arbitrary code — only load
    # checkpoints from a trusted source.
    saved_data = torch.load(
        model_path,
        map_location="cpu",
        weights_only=False,
    )
    bart_best = saved_data["model"]

    # Base architecture + tokenizer come from the pretrained KoBART hub
    # checkpoint; the fine-tuned weights are then loaded on top.
    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')
    model = transformers.BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v1')
    model.load_state_dict(bart_best)

    return model, tokenizer
28
+
29
+
30
+ # main
31
def inference(prompt):
    """Generate essay-style text from *prompt* with the fine-tuned KoBART model.

    Parameters
    ----------
    prompt : str
        Input text to condition the generation on.

    Returns
    -------
    str
        Decoded model output with special tokens stripped.
    """
    # Load the checkpoint once and memoize it on the function object so
    # repeated Gradio requests do not re-read the ~500 MB file from disk
    # and re-download the base model every call.
    if not hasattr(inference, "_cached"):
        inference._cached = load_model(model_path="./kobart-model-essay.pth")
    model, tokenizer = inference._cached

    # Encode and add a leading batch dimension: (seq_len,) -> (1, seq_len).
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(input_ids)

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
45
+
46
+
47
# Bug fix: the original bound `demo` to the *return value* of launch() and
# then called demo.launch() a second time — launching twice and crashing on
# the second call. Build the Interface once and launch it once.
demo = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",  # the generated essay text (return value of `inference`)
)

# Passing share=True here would generate an externally accessible link.
demo.launch()
kobart-model-essay.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdbc4998b5457e983a1d56a3e45f93f034e47a9f72310371f412be2d6bf4f880
3
+ size 496661433