minnehwg committed on
Commit
49f60a6
·
verified ·
1 Parent(s): c8c6527

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -37
app.py CHANGED
@@ -1,37 +0,0 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
- import streamlit as st
3
- import torch
4
-
5
@st.cache_resource
def load_model(cp_path):
    """Load a seq2seq model from the checkpoint ``cp_path``.

    The function is decorated with ``st.cache_resource``.

    Args:
        cp_path: Checkpoint identifier or path handed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.

    Returns:
        The model object produced by ``from_pretrained``.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(cp_path)
9
-
10
@st.cache_resource
def load_tokenizer(path):
    """Load the tokenizer stored at ``path``.

    The function is decorated with ``st.cache_resource``.

    Args:
        path: Tokenizer identifier or path handed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer object produced by ``from_pretrained``.
    """
    # The tokenizer must be returned explicitly: without a return statement
    # the function yields None and every caller receives no tokenizer at all.
    return AutoTokenizer.from_pretrained(path)
13
-
14
# Checkpoint ids of the two fine-tuned summarization models being compared.
cp_aug = 'minnehwg/finetune-newwiki-summarization-ver-augmented2'
cp_org = 'minnehwg/finetune-newwiki-summarization-ver2'

# Both models go through load_model so they are loaded the same way and both
# benefit from its st.cache_resource decorator; calling from_pretrained
# directly for one of them would bypass that decorator.
model_org = load_model(cp_org)
model_aug = load_model(cp_aug)
tokenizer = load_tokenizer("VietAI/vit5-base")
20
-
21
def summarize(text, model, tokenizer, num_beams=4, device='cpu',
              max_input_length=1024, max_summary_length=256):
    """Generate a summary of ``text`` with ``model`` using beam search.

    Args:
        text: The document to summarize.
        model: Object exposing ``eval()``, ``to(device)`` and
            ``generate(inputs, max_length=..., num_beams=...)``.
        tokenizer: Object exposing ``encode(...)`` and ``decode(...)``.
        num_beams: Beam count forwarded to ``model.generate``.
        device: Device the model and the encoded input are moved to.
        max_input_length: Truncation limit passed to ``tokenizer.encode``.
        max_summary_length: ``max_length`` passed to ``model.generate``.

    Returns:
        The decoded summary string, or ``""`` when ``text`` is empty or
        contains only whitespace.
    """
    # Nothing to summarize: skip the (expensive) model call entirely.
    if not text or not text.strip():
        return ""

    model.eval()
    model.to(device)
    inputs = tokenizer.encode(
        text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
        padding=True,
    ).to(device)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        summary_ids = model.generate(
            inputs, max_length=max_summary_length, num_beams=num_beams
        )

    # Only the first generated sequence is decoded.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
29
-
30
# Input widget for the document to summarize. ``text`` has to be defined
# before the check below, otherwise the script fails with a NameError.
text = st.text_area("Enter text to summarize")

if text:
    # summarize() takes (text, model, tokenizer) in that order.
    re1 = summarize(text, model_org, tokenizer)
    re2 = summarize(text, model_aug, tokenizer)
    # Show both summaries side by side as a JSON object.
    out = {
        'Result from model with original data': re1,
        'Result from model with augmented data': re2,
    }
    st.json(out)