import collections
import os
import tempfile

import gradio as gr
import requests
import torch
from safetensors.torch import serialize_file


def download_file(url, local_path):
    """Download a file from a URL to a local path."""
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(local_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    return local_path


def rename_key(rename, name):
    """Apply every matching substring replacement in `rename` to a key name."""
    for k, v in rename.items():
        if k in name:
            name = name.replace(k, v)
    return name


def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
    """Convert a PyTorch .pth checkpoint to .safetensors, renaming and transposing keys as needed."""
    loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    kk = list(loaded.keys())

    # Detect the RWKV model version from characteristic key names and shapes.
    version = 4
    for x in kk:
        if "ln_x" in x:
            version = max(5, version)
        if "gate.weight" in x:
            version = max(5.1, version)
        if int(version) == 5 and "att.time_decay" in x:
            if len(loaded[x].shape) > 1:
                if loaded[x].shape[1] > 1:
                    version = max(5.2, version)
        if "time_maa" in x:
            version = max(6, version)

    print(f"Model detected: v{version:.1f}")

    if version == 5.1:
        # v5.1 stores per-head decay/bonus vectors; expand them to per-channel.
        _, n_emb = loaded["emb.weight"].shape
        for k in kk:
            if "time_decay" in k or "time_faaaa" in k:
                loaded[k] = loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])

    with torch.no_grad():
        for k in kk:
            new_k = rename_key(rename, k).lower()
            v = loaded[k].half()
            del loaded[k]
            for transpose_name in transpose_names:
                if transpose_name in new_k:
                    dims = len(v.shape)
                    v = v.transpose(dims - 2, dims - 1)
                    break
            print(f"{new_k}\t{v.shape}\t{v.dtype}")
            # serialize_file expects raw dtype/shape/bytes entries rather than tensors.
            loaded[new_k] = {
                "dtype": str(v.dtype).split(".")[-1],
                "shape": v.shape,
                "data": v.numpy().tobytes(),
            }

    dirname = os.path.dirname(sf_filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    serialize_file(loaded, sf_filename, metadata={"format": "pt"})
    return sf_filename


def process_model(url):
    """Process the model URL and return a downloadable safetensors file."""
    pth_path = None
    sf_path = None
    try:
        # Create temporary files for the download and the converted output.
        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as temp_pth:
            pth_path = temp_pth.name
        with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as temp_sf:
            sf_path = temp_sf.name

        # Download the .pth file from the URL.
        download_file(url, pth_path)

        # Conversion parameters.
        rename = {
            "time_faaaa": "time_first",
            "time_maa": "time_mix",
            "lora_A": "lora.0",
            "lora_B": "lora.1",
        }
        transpose_names = [
            "time_mix_w1", "time_mix_w2",
            "time_decay_w1", "time_decay_w2",
            "w1", "w2", "a1", "a2", "g1", "g2", "v1", "v2",
            "time_state", "lora.0",
        ]

        # Convert the file.
        converted_file = convert_file(pth_path, sf_path, rename, transpose_names)

        # Clean up the temporary .pth file.
        os.remove(pth_path)

        return converted_file
    except Exception as e:
        # Clean up temporary files in case of error.
        if pth_path and os.path.exists(pth_path):
            os.remove(pth_path)
        if sf_path and os.path.exists(sf_path):
            os.remove(sf_path)
        raise gr.Error(f"Error processing the model: {str(e)}")


# Gradio interface
with gr.Blocks(title="PTH to Safetensors Converter") as demo:
    gr.Markdown("# PTH to Safetensors Converter")
    gr.Markdown(
        "Enter the URL of a `.pth` model file hosted on Hugging Face to convert it "
        "to the `.safetensors` format."
    )

    url_input = gr.Textbox(label="Model URL", placeholder="https://huggingface.co/.../model.pth")
    convert_btn = gr.Button("Convert")
    output_file = gr.File(label="Download Converted Safetensors File")

    convert_btn.click(fn=process_model, inputs=url_input, outputs=output_file)

demo.launch()
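
# Standalone usage sketch (the file paths below are hypothetical placeholders,
# not files shipped with this app): the converter can also be run without the
# Gradio UI by calling convert_file directly on a locally downloaded checkpoint.
#
#     convert_file(
#         "model.pth",                    # hypothetical local .pth checkpoint
#         "converted/model.safetensors",  # hypothetical output path
#         rename={"time_faaaa": "time_first", "time_maa": "time_mix"},
#         transpose_names=["time_mix_w1", "time_mix_w2"],
#     )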