import gradio as gr import os from gradio.themes import default import requests import shutil import uuid from ftplib import FTP from spandrel import ImageModelDescriptor, ModelLoader import torch import subprocess import pkg_resources print("Torch version:",torch.__version__) print("Gradio version:",gr.__version__) mnn_version = pkg_resources.get_distribution("MNN").version print("MNN version:", mnn_version) spandrel_version = pkg_resources.get_distribution("spandrel").version print("Spandrel version:", spandrel_version) pnnx_version = pkg_resources.get_distribution("pnnx").version print("PNNX version:", pnnx_version) # 定义 downloaded_files 变量 downloaded_files = {} # 新增日志开关 log_to_terminal = True # 新增全局任务计数器 task_counter = 0 # 新增日志函数 def print_log(task_id, filename, stage, status): if log_to_terminal: print(f"任务{task_id}: {filename}, [{status}] {stage}") # 修改 start_process 函数,处理新增输入 def start_process(input_file, input_url, input2, shape0_str, shape1_str, output_type, input_suffix=".pth"): global task_counter task_counter += 1 task_id = task_counter input1 = input_file if input_file else input_url print_log(task_id, input2, input1, "input1") print_log(task_id, input2, str(output_type), "output_type") log = "转换过程非常慢,请耐心等待。显示文件列表不代表转换完成。如果未发生错误,转换结束会显示”任务完成“\n" output_files = [] yield [], log if input2 == None or input2.strip() == "": split_input = os.path.splitext(os.path.basename(input1)) if len(split_input) > 1: suffix = split_input[1].split('?')[0].lower() if suffix not in [".pth" , ".safetensors" , ".ckpt"]: print_log(task_id, input2, "不支持此文件的格式 suffix="+suffix, "错误") log += f"不支持此文件的格式\n" return [] , log input2 = split_input[0] print_log(task_id, input2, "检查文件名", "开始") log += f"检查文件名…\n" yield [], log if input2 == None or input2.strip() == "": input2 = str(task_id) log += f"未提供文件名,使用{input2}\n" print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正") yield [], log try: # 判断 input1 是地址还是文件,增加对 ftp 和 webdav 协议的支持 supported_protocols = ('http://', 'https://', 'ftp://', 'webdav://') if isinstance(input1, str) and input1.startswith(supported_protocols): url = input1 if url in downloaded_files and os.path.exists(downloaded_files[url]): file_path = downloaded_files[url] print_log(task_id, input2, "检查下载状态", "跳过下载") log += f"跳过下载,文件已存在: {file_path}\n" yield [], log else: print_log(task_id, input2, "下载文件", "开始") log += f"开始下载文件…\n" yield [], log # 生成唯一文件名 file_name = str(task_id) + input_suffix file_path = os.path.join(os.getcwd(), file_name) if url.startswith('ftp://'): try: # 解析 ftp 地址 parts = url.replace('ftp://', '').split('/') host = parts[0] remote_file_path = '/'.join(parts[1:]) ftp = FTP(host) ftp.login() with open(file_path, 'wb') as f: ftp.retrbinary('RETR ' + remote_file_path, f.write) ftp.quit() downloaded_files[url] = file_path print_log(task_id, input2, "下载文件", "成功") log += f"文件下载成功: {file_path}\n" yield [], log except Exception as e: print_log(task_id, input2, "下载文件", f"失败 (FTP): {str(e)}") log += f"FTP 文件下载失败: {str(e)}\n" yield [], log return else: if url.startswith(('http://', 'https://')): response = requests.get(url) if response.status_code == 200: with open(file_path, 'wb') as f: f.write(response.content) downloaded_files[url] = file_path print_log(task_id, input2, "下载文件", "成功") log += f"文件下载成功: {file_path}\n" yield [], log else: print_log(task_id, input2, f"下载文件(HTTP): {response.status_code}", "失败") log += f"文件下载失败,状态码: {response.status_code}\n" yield [], log return elif input1 is not None: print("check file" , input1, os.path.exists(input1)) file_path = input1 log += f"使用上传的文件: {file_path}\n" print_log(task_id, input2, "使用上传文件", "开始") yield [], log else: log += "未提供有效文件或地址\n" print_log(task_id, input2, "检查文件输入", "失败 (无有效输入)") yield [], log return # 检查文件大小 try: file_size = os.path.getsize(file_path) / 1024 /1024 # 转换为 KB if file_size > 200 : log += f"文件太大,建议 200MB 以内,当前文件大小为 {file_size } MB。\n" print_log(task_id, input2, "文件太大("+ file_size +"MB)", "失败") yield [], log return except Exception as e: log += f"获取文件大小失败: {str(e)}\n" print_log(task_id, input2, "检查文件大小", f"失败: {str(e)}") yield [], log return # 生成新文件夹用于暂存结果 output_folder = os.path.join(os.getcwd(), str(uuid.uuid4())) os.makedirs(output_folder, exist_ok=True) print_log(task_id, input2, "创建临时文件夹", "完成") log += f"创建临时文件夹: {output_folder}\n生成张量\n" yield [], log # 解析输入的字符串为数组 try: # 尝试解析 shape0_str shape0 = [int(x) for x in shape0_str.split(',')] if shape0_str else [0, 0, 0, 0] # 检查 shape0 是否为 4 个元素,如果不是则设置为全 0 if len(shape0) != 4: shape0 = [0, 0, 0, 0] # 尝试解析 shape1_str shape1 = [int(x) for x in shape1_str.split(',')] if shape1_str else [0, 0, 0, 0] # 检查 shape1 是否为 4 个元素,如果不是则设置为全 0 if len(shape1) != 4: shape1 = [0, 0, 0, 0] except ValueError: # 如果解析过程中出现 ValueError,将 shape0 和 shape1 设置为全 0 shape0 = [0, 0, 0, 0] shape1 = [0, 0, 0, 0] log += f"输入的 shape 字符串格式不正确,请使用逗号分隔的整数。shape0_str={shape0_str},shape1_str={shape1_str}\n" yield [], log return # 以下是 process_file 函数的代码 # 使用 torch.rand 生成 input_shape print_log(task_id, input2, "生成输入张量", "开始") log += "生成张量…\n" yield [], log output_base = output_folder + "/" + input2 pt_path = output_base + ".pt" command = f"pnnx {pt_path}" input_tensor0 = torch.rand(shape0) if any(shape0) else None input_tensor1 = torch.rand(shape1) if any(shape1) else None if input_tensor0 is not None and input_tensor1 is not None: example_input = (input_tensor0, input_tensor1) # 修改此处,去除 shape 字符串中的空格 if "Fixed" in output_type: command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')} inputshape2={str(shape1).replace(' ', '')}" elif input_tensor0 is not None: example_input = input_tensor0 if "Fixed" in output_type: command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')}" else: example_input = input_tensor1 command = f"pnnx {pt_path}" input_tensor_str = "" if input_tensor0 is not None: input_tensor_str += str(input_tensor0.shape) else: input_tensor_str += "None" if input_tensor1 is not None: input_tensor_str += ", " + str(input_tensor1.shape) else: input_tensor_str += ", None" print_log(task_id, input2, "生成输入张量"+input_tensor_str, "完成") log +=input_tensor_str+ "\n" yield [], log # 确保 output_folder 存在 if not os.path.exists(output_folder): os.makedirs(output_folder) print_log(task_id, input2, "加载模型", "开始") log += "加载模型…\n" yield [], log # load a model from disk model = ModelLoader().load_from_file(file_path) # make sure it's an image to image model assert isinstance(model, ImageModelDescriptor) print_log(task_id, input2, "获得模型对象", "开始") log += "获得模型对象…\n" yield [], log # send it to the GPU and put it in inference mode # model.cuda().eval() model.eval() torch_model = model.model print_log(task_id, input2, "获得模型对象", "完成") yield [], log width_ratio = 0 if os.path.exists(pt_path): print_log(task_id, input2, "转换为TorchScript模型", "跳过") log += "跳过转换为TorchScript模型\n" yield [], log elif "TorchScript" in output_type: print_log(task_id, input2, "转换为TorchScript模型", "开始") log+= "转换为TorchScript模型…\n" yield [], log # 使用 torch.jit.trace 进行模型转换 traced_torch_model = torch.jit.trace(torch_model, example_input) traced_torch_model.save(output_folder + "/" + input2 + ".pt") print_log(task_id, input2, "转换为TorchScript模型", "完成") # 获取输出 example_output = traced_torch_model(example_input) if isinstance(example_output, torch.Tensor): width_ratio = example_output.shape[2] / example_input.shape[2] print_log(task_id, input2, "获得缩放倍率="+ str(width_ratio)+", 输出shape="+str(list(example_output.shape)), "完成") log+= ("获得缩放倍率="+str(width_ratio)+", 输出shape="+str(list(example_output.shape))+"\n") yield [], log else: print_log(task_id, input2, "Traced torch model输出" + type(example_output), "错误") log+="Traced torch model输出" + type(example_output)+ "错误\n" yield [], log scale = int(width_ratio) # 转换为 ONNX 模型 if "ONNX" in output_type or "NCNN" in output_type or "MNN" in output_type: if str(scale) in input2 or scale <1: onnx_path = output_base + ".onnx" else: onnx_path = output_base + "-x" + str(scale) + ".onnx" if os.path.exists(onnx_path): print_log(task_id, input2, "转换为ONNX模型", "跳过") log += "跳过转换为ONNX模型\n" yield [], log else: print_log(task_id, input2, "转换为ONNX模型", "开始") log += "转换为ONNX模型…\n" yield [], log torch.onnx.export(torch_model, example_input, onnx_path, opset_version=17, input_names=["input"], output_names=["output"]) # 转换为 mnn 模型 if "MNN" in output_type: if str(scale) in input2 or scale < 1: mnn_path = output_base + ".mnn" else: mnn_path = output_base + "-x" + str(scale) + ".mnn" mnn_config = "" if "Fixed" in output_type and input_tensor0 is not None: mnn_config = output_base + ".mnnconfig" with open(mnn_config, 'w') as f: if input_tensor1 is not None: f.write(f"input_names = input0, input1\n") f.write(f"input_dims = {'x'.join(map(str, shape0))}, {'x'.join(map(str, shape0))},\n") else: f.write(f"input_names = input\n") f.write(f"input_dims = {'x'.join(map(str, shape0))}\n") if os.path.exists(mnn_path): print_log(task_id, input2, "转换为MNN模型", "跳过") log += "跳过转换为MNN模型\n" yield [], log else: print_log(task_id, input2, "转换为MNN模型", "开始") log += "转换为MNN模型…\n" mnn_command = f"MNNConvert -f ONNX --modelFile \"{onnx_path}\" --MNNModel \"{mnn_path}\" --bizCode biz --fp16 --info --detectSparseSpeedUp" if mnn_config: mnn_command += f" --inputConfigFile \"{mnn_config}\"" try: # 使用 subprocess.Popen 执行命令 process = subprocess.Popen(mnn_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) while True: output = process.stdout.readline() if output == '' and process.poll() is not None: break if output: # if log_to_terminal: print(output.strip()) log += output.strip() + '\n' yield [], log returncode = process.poll() if returncode != 0: print_log(task_id, input2, f"转换为MNN模型,返回码: {returncode},命令: {mnn_command} ", "错误") log += f"执行mnn命令失败,返回码: {returncode},命令: {mnn_command} \n" else: log += f"执行mnn命令成功: {mnn_command} \n" except Exception as e: log += f"执行mnn命令: {mnn_command} 失败,错误信息: {str(e)}\n" print_log(task_id, input2, f"转换为MNN模型,错误信息: {str(e)}", "错误") if "NCNN" in output_type: print_log(task_id, input2, "执行ncnn命令" + command, "开始") log += "执行ncnn命令…\n" yield [], log try: # 使用 subprocess.Popen 执行命令 process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) while True: output = process.stdout.readline() if output == '' and process.poll() is not None: break if output: # if log_to_terminal: print(output.strip()) log += output.strip() + '\n' yield [], log returncode = process.poll() if returncode != 0: log += f"执行ncnn命令失败,返回码: {returncode},命令: {command} \n" print_log(task_id, input2, f"返回码: {returncode},命令: {command} ", "失败") else: log += f"执行ncnn命令成功: {command} \n" except Exception as e: log += f"执行ncnn命令: {command} 失败,错误信息: {str(e)}\n" print_log(task_id, input2, f"错误信息: {str(e)}", "错误") # 查找 output_folder 目录下以 .ncnn.bin 和 .ncnn.param 结尾的文件 bin_files = [f for f in os.listdir(output_folder) if f.endswith('.ncnn.bin')] param_files = [f for f in os.listdir(output_folder) if f.endswith('.ncnn.param')] if bin_files and param_files: param_file = os.path.join(output_folder, param_files[0]) bin_file = os.path.join(output_folder, bin_files[0]) import zipfile # 压缩包名称 zip_file_name = os.path.join(output_folder, f"models-{input2}.zip") # 压缩包内文件夹名称 zip_folder_name = f"models-{input2}" # 重命名后的文件名 new_bin_name = f"x{scale}.bin" new_param_name = f"x{scale}.param" # 创建压缩包 with zipfile.ZipFile(zip_file_name, 'w', zipfile.ZIP_DEFLATED) as zipf: # 写入重命名后的.bin文件 zipf.write(bin_file, os.path.join(zip_folder_name, new_bin_name)) # 写入重命名后的.param文件 zipf.write(param_file, os.path.join(zip_folder_name, new_param_name)) log += f"已创建压缩包: {zip_file_name}\n" print_log(task_id, input2, "创建压缩包"+zip_file_name, "完成") yield [], log else: log += f"未找到 ncnn 文件\n" print_log(task_id, input2, "查找 ncnn 文件", "失败") yield [], log output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))] log += f"任务完成\n" print_log(task_id, input2, "执行命令", "完成") yield output_files, log except Exception as e: log += f"发生错误: {e}\n" print_log(task_id, input2, e , f"失败") yield [], log # 创建 Gradio 界面 with gr.Blocks() as demo: gr.Markdown("文件处理界面") with gr.Row(): # 左侧列,包含输入组件和按钮 with gr.Column(): # 添加文本提示 gr.Markdown("请输入的url,或者上传一个文件。限制文件为小于100M的*.pth模型") with gr.Row(): input1 = gr.Textbox(label="粘贴地址") # 新增文件上传组件 input1_file = gr.File(label="上传文件", file_types=[".pth", ".safetensors", ".ckpt"]) with gr.Row(): input2 = gr.Textbox(label="自定义文件名") output_type = gr.Dropdown( choices=["TorchScript", "ONNX", "Fixed", "MNN", "NCNN"], value=["TorchScript", "ONNX", "MNN", "NCNN"], multiselect=True, label="模型类型", info="1. 生成mnn和ncnn模型必须先生成onnx模型;2.如果选项中包含了Fixed,那么输出的onnx和mnn模型都使用固定shape的input。" ) shape0_str = gr.Textbox(label="shape0 (逗号分隔的整数)", value="1,3,128,128") shape1_str = gr.Textbox(label="shape1 (逗号分隔的整数)", value="0,0,0,0") with gr.Row(): start_button = gr.Button("开始") # 添加取消按钮 cancel_button = gr.Button("取消") # 右侧列,包含输出组件和日志文本框 with gr.Column(): output = gr.File(label="输出文件", file_count="multiple") log_textbox = gr.Textbox(label="日志", lines=10, interactive=False) # 绑定事件,修改输入参数 process = start_button.click( fn=start_process, inputs=[input1_file, input1, input2, shape0_str, shape1_str, output_type], outputs=[output, log_textbox] ) # 为取消按钮添加点击事件绑定,使用 cancels 属性取消 start_process 任务 cancel_button.click( fn=None, inputs=None, outputs=None, cancels=[process] ) # 添加范例 examples = [ [None, "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", "", "1,3,128,128", "0,0,0,0", ["TorchScript", "ONNX", "MNN", "ONNX"]], [None, "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth", "", "1,3,128,128", "0,0,0,0", ["TorchScript", "ONNX", "MNN", "ONNX"]], [None, "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth", "", "1,3,128,128", "0,0,0,0", ["TorchScript", "ONNX", "MNN", "ONNX"]], [None, "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth", "", "1,3,128,128", "0,0,0,0", ["ONNX", "MNN"]], ] gr.Examples( examples=examples, inputs=[input1_file, input1, input2, shape0_str, shape1_str, output_type], outputs=[output, log_textbox], fn=start_process, ) demo.launch(ssr_mode=False)