|
import gradio as gr |
|
import os |
|
from loguru import logger |
|
import torch |
|
from src.model import TransformerModel |
|
from src.tokenizer import Tokenizer |
|
from src.eval import evaluate |
|
from src.config import Config |
|
from gradio.themes import Default, Soft |
|
|
|
# Directory scanned for checkpoint files shown in the UI model dropdown.
model_folder_path = './models'

# Full path of the currently selected checkpoint; mutated by load_model().
model_path = ''
|
|
|
|
|
def show_config(config_path='./config.yaml'):
    """Return the raw text of the YAML config file.

    Args:
        config_path: Path of the config file to read. Defaults to
            './config.yaml', matching the previously hard-coded location,
            so existing callers (e.g. the demo.load hook) are unaffected.

    Returns:
        The file contents as a string, or a Chinese error message when
        the file does not exist.
    """
    try:
        # Keep only the statement that can raise inside the try block.
        with open(config_path, 'r', encoding='utf-8') as file:
            content = file.read()
    except FileNotFoundError:
        return "未找到config文件,请检查文件路径。"
    return content
|
|
|
def get_model_files(folder_path=None):
    """List the regular files (candidate checkpoints) in a folder.

    Args:
        folder_path: Directory to scan. When None (the default), falls
            back to the module-level ``model_folder_path``, preserving
            the original behavior for existing callers.

    Returns:
        Bare filenames (not full paths) of regular files in the folder.
        Returns [] when the path is missing or is not a directory
        (the original raised NotADirectoryError for a non-directory path).
    """
    if folder_path is None:
        folder_path = model_folder_path
    # isdir (rather than exists) also guards against a plain file at the path.
    if not os.path.isdir(folder_path):
        return []
    return [f for f in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, f))]
|
|
|
def greet(name):
    """Return a Chinese greeting addressed to *name*."""
    greeting = f"你好, {name}!"
    return greeting
|
|
|
def load_model(selected_model, model):
    """Load the chosen checkpoint's weights into *model*.

    Side effects: updates the module-level ``model_path``, overwrites the
    model's parameters, switches it to eval mode, and moves it to the
    module-level DEVICE.

    Args:
        selected_model: Filename (within ``model_folder_path``) picked
            in the UI dropdown.
        model: The TransformerModel instance to load weights into.
    """
    global model_path
    model_path = os.path.join(model_folder_path, selected_model)
    logger.info(f'选择模型路径为{model_path}')
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    total_params = sum(weights.numel() for weights in checkpoint.values())
    logger.info(f"加载的state_dict大小: {total_params}")
    model.load_state_dict(checkpoint)
    model.eval()
    model.to(DEVICE)
|
|
|
|
|
def wrapper_translate(input_text):
    """Translate *input_text* with the strategy selected in the config.

    Dispatches on the NO_UNK and BEAM_SEARCH flags (expected 0 or 1)
    from the module-level ``config`` to one of the model's four
    translate variants, logs the result, and returns it.

    Args:
        input_text: English source sentence from the UI textbox.

    Returns:
        The translated string (second element of the model's result).

    Raises:
        ValueError: if either flag is not 0 or 1. (The original if/elif
        chain left the result variable unbound in that case, crashing
        with UnboundLocalError.)
    """
    no_unk = config.config['NO_UNK']
    beam_search = config.config['BEAM_SEARCH']
    # Dispatch table instead of a four-way if/elif chain.
    dispatch = {
        (0, 0): model.translate,
        (0, 1): model.translate_beam_search,
        (1, 0): model.translate_no_unk,
        (1, 1): model.translate_no_unk_beam_search,
    }
    try:
        translate_fn = dispatch[(no_unk, beam_search)]
    except KeyError:
        raise ValueError(
            f"Unsupported config flags: NO_UNK={no_unk}, BEAM_SEARCH={beam_search}"
        ) from None
    _, translated = translate_fn(input_text)
    logger.info(translated)
    return translated
|
|
|
# Theme for the whole Gradio app.
custom_theme = Soft()

# Build the Gradio UI. Tokenizers, config, DEVICE, and the model are created
# here at import time and used as module-level names by the handler functions
# defined above (load_model reads DEVICE; wrapper_translate reads config/model).
with gr.Blocks(title="南京大学智科NLP小作业-Transformer翻译系统", theme=custom_theme) as demo:
    # Source (English) and target (Chinese) tokenizers from prebuilt word tables.
    en_tokenizer = Tokenizer(word2int_path='./wordtable/word2int_en.json', int2word_path='./wordtable/int2word_en.json')
    cn_tokenizer = Tokenizer(word2int_path='./wordtable/word2int_cn.json', int2word_path='./wordtable/int2word_cn.json')
    config = Config('./config.yaml')
    # Device string (e.g. 'cpu'/'cuda') comes from the YAML config.
    DEVICE = torch.device(config.config['DEVICE'])
    # Model is constructed with random weights; real weights are loaded
    # later via the dropdown's change handler (load_model).
    model = TransformerModel(config=config,src_tokenizer=en_tokenizer, tgt_tokenizer=cn_tokenizer)

    gr.Markdown("# 欢迎试用Demo!!!")
    with gr.Row():
        # Read-only view of config.yaml, filled by demo.load below.
        config_display = gr.Textbox(label="config content", lines=12)
    with gr.Row():
        # Checkpoint files found under model_folder_path at startup.
        model_files = get_model_files()
        logger.info(model_files)
        model_selector = gr.Dropdown(
            choices=model_files,
            label="选择模型",

        )
    with gr.Row():
        # Translation input/output boxes.
        name_input = gr.Textbox(label="输入英文")
        output_text = gr.Textbox(label="中文翻译")
    with gr.Row():
        greet_btn = gr.Button("翻译")

    # Button click runs the config-driven translation on the input text.
    greet_btn.click(
        fn=wrapper_translate,
        inputs=name_input,
        outputs=output_text
    )

    # On page load, show the raw config file contents.
    demo.load(
        fn=show_config,
        inputs=None,
        outputs=config_display
    )
    # Selecting a checkpoint loads its weights into the shared model
    # (the lambda captures the module-level model instance).
    model_selector.change(
        fn=lambda selected_model: load_model(selected_model, model),
        inputs=model_selector,
        outputs=None
    )
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()