File size: 3,694 Bytes
f8bd4d2 a690efe f8bd4d2 a690efe f8bd4d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import gradio as gr
import os
from loguru import logger
import torch
from src.model import TransformerModel
from src.tokenizer import Tokenizer
from src.eval import evaluate
from src.config import Config
from gradio.themes import Default, Soft
# Directory scanned for checkpoint files offered in the model-selector dropdown.
model_folder_path = './models'
# Full path of the most recently selected checkpoint; set by load_model().
model_path = ''
def show_config():
    """Return the raw text of ./config.yaml for display in the UI.

    Returns a (Chinese) error message instead of raising when the file
    is missing, so the textbox always has something to show.
    """
    try:
        with open('./config.yaml', 'r', encoding='utf-8') as fp:
            return fp.read()
    except FileNotFoundError:
        return "未找到config文件,请检查文件路径。"
def get_model_files():
    """List the plain-file names inside ``model_folder_path``.

    Returns an empty list when the folder does not exist, so the
    dropdown degrades gracefully on a fresh checkout.
    """
    if not os.path.exists(model_folder_path):
        return []
    files = []
    for name in os.listdir(model_folder_path):
        if os.path.isfile(os.path.join(model_folder_path, name)):
            files.append(name)
    return files
def greet(name):
    """Return a Chinese greeting for *name*.

    NOTE(review): not wired to any UI component below — appears unused.
    """
    return "你好, " + name + "!"
def load_model(selected_model, model):
    """Load the checkpoint *selected_model* (a filename under
    ``model_folder_path``) into *model* and move it to ``DEVICE``.

    Parameters
    ----------
    selected_model : str | None
        Filename chosen in the Gradio dropdown. May be None/"" when the
        selection is cleared, in which case loading is skipped.
    model : TransformerModel
        The module whose parameters are replaced in place.

    Side effects: updates the module-level ``model_path`` and puts the
    model into eval mode on ``DEVICE``.
    """
    global model_path
    # A Dropdown.change event can fire with an empty selection; previously
    # this crashed in os.path.join(..., None).
    if not selected_model:
        logger.warning("No model selected; skipping checkpoint load.")
        return
    model_path = os.path.join(model_folder_path, selected_model)
    logger.info(f'选择模型路径为{model_path}')
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoints (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    logger.info(f"加载的state_dict大小: {sum(param.numel() for param in state_dict.values())}")
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)
# 定义一个包装函数,只返回第二个变量
def wrapper_translate(input_text):
if config.config['NO_UNK'] == 0 and config.config['BEAM_SEARCH'] == 0:
_, second_variable = model.translate(input_text)
elif config.config['NO_UNK'] == 0 and config.config['BEAM_SEARCH'] == 1:
_, second_variable = model.translate_beam_search(input_text)
elif config.config['NO_UNK'] == 1 and config.config['BEAM_SEARCH'] == 0:
_, second_variable = model.translate_no_unk(input_text)
elif config.config['NO_UNK'] == 1 and config.config['BEAM_SEARCH'] == 1:
_, second_variable = model.translate_no_unk_beam_search(input_text)
logger.info(second_variable)
return second_variable
custom_theme = Soft()
# Build the Gradio UI. Tokenizers, config and the model are created inside
# the Blocks context but live as module globals, so the handler functions
# defined above (wrapper_translate, load_model) can see them.
with gr.Blocks(title="南京大学智科NLP小作业-Transformer翻译系统", theme=custom_theme) as demo:
    en_tokenizer = Tokenizer(word2int_path='./wordtable/word2int_en.json', int2word_path='./wordtable/int2word_en.json')
    cn_tokenizer = Tokenizer(word2int_path='./wordtable/word2int_cn.json', int2word_path='./wordtable/int2word_cn.json')
    config = Config('./config.yaml')
    DEVICE = torch.device(config.config['DEVICE'])
    model = TransformerModel(config=config,src_tokenizer=en_tokenizer, tgt_tokenizer=cn_tokenizer)
    gr.Markdown("# 欢迎试用Demo!!!")
    with gr.Row():
        # Read-only view of the YAML config, filled by demo.load below.
        config_display = gr.Textbox(label="config content", lines=12)
    with gr.Row():
        model_files = get_model_files()  # checkpoint filenames found under ./models
        logger.info(model_files)
        model_selector = gr.Dropdown(
            choices=model_files,
            label="选择模型",
            #value=None if not model_files else model_files[0]  # optionally pre-select a default
        )
    with gr.Row():
        name_input = gr.Textbox(label="输入英文")
        output_text = gr.Textbox(label="中文翻译")
    with gr.Row():
        greet_btn = gr.Button("翻译")
    # Button click runs the translation wrapper on the English input.
    greet_btn.click(
        fn=wrapper_translate,
        inputs=name_input,
        outputs=output_text
    )
    # On page load, show the raw config file in the textbox.
    demo.load(
        fn=show_config,
        inputs=None,
        outputs=config_display
    )
    # Selecting a checkpoint loads its weights into the shared model.
    model_selector.change(
        fn=lambda selected_model: load_model(selected_model, model),
        inputs=model_selector,
        outputs=None
    )
# Start the local Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()