Unggi committed on
Commit
3b59010
·
1 Parent(s): f8c52e0
Files changed (2) hide show
  1. app.py +50 -4
  2. kobart-model-summary.pth +3 -0
app.py CHANGED
@@ -1,7 +1,53 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
# Install torch/transformers at startup (Hugging Face Space without a
# requirements.txt). NOTE(review): `pip.main` was removed from pip's public
# API in pip 10 — the supported way to drive pip programmatically is to run
# it as a subprocess of the current interpreter.
import subprocess
import sys

for _pkg in ("torch", "transformers"):
    subprocess.check_call([sys.executable, "-m", "pip", "install", _pkg])

import torch
import transformers
10
+
11
+
12
# saved_model
def load_model(model_path):
    """Load the fine-tuned KoBART summarization weights and tokenizer.

    Args:
        model_path: Path to a checkpoint saved with torch.save; expected to
            be a dict with a "model" key holding a state_dict. (The original
            also read an unused "config" key — dropped here.)

    Returns:
        A ``(model, tokenizer)`` tuple, with the model on CPU.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source (this one ships with the Space).
    saved_data = torch.load(model_path, map_location="cpu")

    bart_best = saved_data["model"]
    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(
        'gogamza/kobart-base-v1'
    )

    # Instantiate the base architecture, then overwrite its parameters with
    # the fine-tuned weights from the checkpoint.
    model = transformers.BartForConditionalGeneration.from_pretrained(
        'gogamza/kobart-base-v1'
    )
    model.load_state_dict(bart_best)

    return model, tokenizer
28
+
29
+
30
# Cache so the large (~500 MB) checkpoint is deserialized once per process
# instead of on every Gradio request.
_MODEL_CACHE = {}


# main
def inference(prompt):
    """Summarize the given text with the fine-tuned KoBART model.

    Args:
        prompt: Input text from the Gradio textbox.

    Returns:
        The generated summary with special tokens stripped.
    """
    model_path = "./kobart-model-summary.pth"

    # Load lazily on first call, then reuse (original reloaded every call).
    if model_path not in _MODEL_CACHE:
        _MODEL_CACHE[model_path] = load_model(model_path=model_path)
    model, tokenizer = _MODEL_CACHE[model_path]

    # Encode and add a batch dimension of 1, as generate() expects.
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    output = model.generate(input_ids)
    return tokenizer.decode(output[0], skip_special_tokens=True)
45
+
46
+
47
+ demo = gr.Interface(
48
+ fn=inference,
49
+ inputs="text",
50
+ outputs="text" #return κ°’
51
+ ).launch() # launch(share=True)λ₯Ό μ„€μ •ν•˜λ©΄ μ™ΈλΆ€μ—μ„œ 접속 κ°€λŠ₯ν•œ 링크가 생성됨
52
+
53
+ demo.launch()
kobart-model-summary.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:533e88edefe7f9621fb13ac6df3e7b288afe876218c8bf9468900f04014a8857
3
+ size 496661433