"""Gradio demo for an English-to-Chinese Transformer translation model.

The page shows the contents of ``./config.yaml``, lets the user pick a
checkpoint file from ``./models``, and translates the text typed into the
input box with the method selected by the ``NO_UNK`` / ``BEAM_SEARCH``
config entries.
"""
import os

import gradio as gr
import torch
from gradio.themes import Default, Soft
from loguru import logger

from src.model import TransformerModel
from src.tokenizer import Tokenizer
from src.eval import evaluate
from src.config import Config

model_folder_path = './models'
model_path = ''

# Maps (NO_UNK, BEAM_SEARCH) config flags to the model method to call.
# Method names are stored (not bound methods) so that only the selected
# attribute is looked up on the model, as in the original if/elif chain.
_TRANSLATE_METHODS = {
    (0, 0): 'translate',
    (0, 1): 'translate_beam_search',
    (1, 0): 'translate_no_unk',
    (1, 1): 'translate_no_unk_beam_search',
}


def show_config() -> str:
    """Return the text of ``./config.yaml``.

    Returns:
        The file contents, or a user-facing "not found" message when the
        file does not exist.
    """
    try:
        with open('./config.yaml', 'r', encoding='utf-8') as file:
            content = file.read()
        return content
    except FileNotFoundError:
        return "未找到config文件,请检查文件路径。"


def get_model_files():
    """List the regular files directly inside ``model_folder_path``.

    Returns:
        A sorted list of file names (not full paths). Empty when the folder
        is missing or is not a directory.
    """
    # isdir (rather than exists) also covers the case where the path exists
    # but is a regular file, which would make os.listdir raise.
    if not os.path.isdir(model_folder_path):
        return []
    # Sorted so the dropdown order does not depend on the filesystem.
    return sorted(
        f for f in os.listdir(model_folder_path)
        if os.path.isfile(os.path.join(model_folder_path, f))
    )


def greet(name) -> str:
    """Return a greeting string for ``name``."""
    return f"你好, {name}!"


def load_model(selected_model, model) -> None:
    """Load the selected checkpoint's state dict into ``model``.

    Updates the module-level ``model_path``, loads the state dict onto the
    CPU, applies it to ``model``, switches the model to eval mode and moves
    it to ``DEVICE``.

    Args:
        selected_model: File name inside ``model_folder_path``. ``None`` or
            an empty string (e.g. a cleared dropdown) is ignored.
        model: Object exposing ``load_state_dict``, ``eval`` and ``to``.
    """
    global model_path
    if not selected_model:
        # The dropdown's change event can fire with no selection; joining
        # None into a path would raise TypeError.
        logger.warning('No model selected; skipping model load.')
        return

    model_path = os.path.join(model_folder_path, selected_model)
    logger.info(f'选择模型路径为{model_path}')
    # NOTE(security): torch.load unpickles the file. Only place trusted
    # checkpoints in the models folder.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    param_count = sum(param.numel() for param in state_dict.values())
    logger.info(f"加载的state_dict大小: {param_count}")
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)


def wrapper_translate(input_text):
    """Translate ``input_text`` and return only the second result element.

    The model method is chosen from the ``NO_UNK`` and ``BEAM_SEARCH``
    entries of the global config. Each method returns a pair; the first
    element is discarded.

    Args:
        input_text: Text passed through to the model's translate method.

    Returns:
        The second element of the pair returned by the selected method.

    Raises:
        ValueError: If the flag combination is not one of the supported
            0/1 pairs (previously this surfaced as UnboundLocalError).
    """
    no_unk = config.config['NO_UNK']
    beam_search = config.config['BEAM_SEARCH']
    method_name = _TRANSLATE_METHODS.get((no_unk, beam_search))
    if method_name is None:
        raise ValueError(
            f"Unsupported config: NO_UNK={no_unk!r}, "
            f"BEAM_SEARCH={beam_search!r} (each must be 0 or 1)"
        )
    _, second_variable = getattr(model, method_name)(input_text)
    logger.info(second_variable)
    return second_variable


custom_theme = Soft()

with gr.Blocks(title="南京大学智科NLP小作业-Transformer翻译系统", theme=custom_theme) as demo:
    en_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_en.json',
        int2word_path='./wordtable/int2word_en.json',
    )
    cn_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_cn.json',
        int2word_path='./wordtable/int2word_cn.json',
    )
    config = Config('./config.yaml')
    DEVICE = torch.device(config.config['DEVICE'])
    model = TransformerModel(
        config=config,
        src_tokenizer=en_tokenizer,
        tgt_tokenizer=cn_tokenizer,
    )

    gr.Markdown("# 欢迎试用Demo!!!")

    with gr.Row():
        config_display = gr.Textbox(label="config content", lines=12)

    with gr.Row():
        # Checkpoint files available for selection.
        model_files = get_model_files()
        logger.info(model_files)
        # No default value is set, so no checkpoint is loaded until the
        # user picks one and the change handler below runs.
        model_selector = gr.Dropdown(
            choices=model_files,
            label="选择模型",
        )

    with gr.Row():
        name_input = gr.Textbox(label="输入英文")
        output_text = gr.Textbox(label="中文翻译")

    with gr.Row():
        greet_btn = gr.Button("翻译")

    greet_btn.click(
        fn=wrapper_translate,
        inputs=name_input,
        outputs=output_text,
    )

    # Fill the config textbox when the page loads.
    demo.load(
        fn=show_config,
        inputs=None,
        outputs=config_display,
    )

    # Load the chosen checkpoint into the shared model instance.
    model_selector.change(
        fn=lambda selected_model: load_model(selected_model, model),
        inputs=model_selector,
        outputs=None,
    )

if __name__ == "__main__":
    demo.launch()