import gradio as gr

# Install torch and transformers at runtime via pip.
# (On Hugging Face Spaces these would normally be listed in requirements.txt instead.)
import pip
pip.main(['install', 'torch'])
pip.main(['install', 'transformers'])

import torch
import transformers
# Load the fine-tuned KoBART model and tokenizer from a saved checkpoint.
def load_model(model_path):
    saved_data = torch.load(
        model_path,
        map_location="cpu"
    )
    bart_best = saved_data["model"]      # fine-tuned state_dict
    train_config = saved_data["config"]  # training configuration (not used here)

    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')

    # Load the fine-tuned weights into the base model.
    model = transformers.BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v1')
    model.load_state_dict(bart_best)

    return model, tokenizer
# Gradio callback: generate text for a user prompt.
def inference(prompt):
    model_path = "./kobart-model-essay.pth"
    model, tokenizer = load_model(
        model_path=model_path
    )

    input_ids = tokenizer.encode(prompt)
    input_ids = torch.tensor(input_ids)
    input_ids = input_ids.unsqueeze(0)  # add batch dimension

    output = model.generate(input_ids)
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return output
demo = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",  # the generated text is returned as a string
    examples=[
        "꿈 μ†μ—μ„œ λ‚˜λŠ” λ§ˆλ²•μ˜ 숲으둜 λ– λ‚˜κ²Œ λ˜μ—ˆλ‹€. λ§ˆλ²•μ˜ μˆ²μ—μ„œ λ‚˜λŠ” λΉ—μžλ£¨λ₯Ό 타고 λ‚ μ•„λ‹€λ…”λ‹€. μˆ²μ„ λ‚ μ•„λ‹€λ‹ˆλŠ” 도쀑, λ‚˜λŠ” μ‹ λΉ„λ‘œμš΄ 성을 λ°œκ²¬ν•˜κ²Œ λ˜μ—ˆλ‹€. κ·Έ μ„± μ•ˆμ—λŠ” 무엇이 μžˆμ„κΉŒ? λ‚˜λŠ” κ·Έ μ„± μ•ˆμœΌλ‘œ λ“€μ–΄κ°”λ‹€. μ„± μ•ˆμ—λŠ” λ§ˆλ²•μ‚¬κ°€ μ‚΄κ³  μžˆμ—ˆλŠ”λ°, λ‚˜μ—κ²Œ λ§ˆλ²•μ„ κ°€λ₯΄μ³ μ£Όμ—ˆλ‹€. κ·Έ λ§ˆλ²•μœΌλ‘œ λ‚˜λŠ” λ‚΄κ°€ μ’‹μ•„ν•˜λŠ” μŒμ‹μ„ λ§Œλ“€μ–΄μ„œ 마음껏 λ¨Ήμ—ˆλ‹€!"
    ]
)

demo.launch()  # launch(share=True) would create a publicly accessible link