# File size: 4,246 Bytes
# d4365d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import collections
import numpy as np
import os
import torch
from safetensors.torch import serialize_file
import requests
import tempfile

def download_file(url, local_path):
    """Download a file from a URL to a local path.

    The body is streamed to disk in 8 KiB chunks so the whole file is never
    held in memory.

    Args:
        url: Address to fetch with an HTTP GET.
        local_path: Destination file; it is created or overwritten.

    Returns:
        ``local_path``, unchanged.

    Raises:
        requests.exceptions.HTTPError: If the server answers with an error
            status (raised by ``raise_for_status``).
        requests.exceptions.Timeout: If connecting takes longer than 10 s or
            the server stays silent for more than 60 s between reads.
    """
    # The timeout is (connect, read-between-bytes), not a cap on the total
    # transfer time, so large downloads still complete; it only prevents a
    # stalled server from hanging the worker forever.
    # The ``with`` block releases the streamed connection even when the
    # status check or a disk write fails.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return local_path

def rename_key(rename, name):
    """Return *name* with each substring listed in *rename* substituted.

    Substitutions are applied one after another in the mapping's iteration
    order, so a later entry sees the result of the earlier ones.
    """
    renamed = name
    for old, new in rename.items():
        if old not in renamed:
            continue
        renamed = renamed.replace(old, new)
    return renamed

def convert_file(pt_filename: str, sf_filename: str, rename=None, transpose_names=None):
    """Convert a PyTorch checkpoint into a half-precision safetensors file.

    The checkpoint is loaded on the CPU, a model version is guessed from the
    parameter names, every tensor is cast to float16, its key is renamed and
    lower-cased, selected tensors get their last two axes swapped, and the
    result is written with ``serialize_file``.

    Args:
        pt_filename: Path of the checkpoint to read with ``torch.load``.
        sf_filename: Path of the safetensors file to write. Its parent
            directory is created when the path has one.
        rename: Mapping of substring -> replacement applied to every key via
            ``rename_key``. ``None`` (the default) means no renaming.
        transpose_names: Substrings; a tensor whose *new* key contains any of
            them has its last two dimensions transposed. ``None`` (the
            default) means no transposition.

    Returns:
        ``sf_filename``, unchanged.
    """
    # Use None sentinels instead of mutable default arguments.
    rename = {} if rename is None else rename
    transpose_names = [] if transpose_names is None else transpose_names

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only feed it checkpoints from a trusted source, or
    # consider ``weights_only=True`` if the supported torch version has it.
    loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    # Snapshot the keys: entries are removed from ``loaded`` further down.
    kk = list(loaded.keys())

    # Guess the model version from marker substrings in the parameter names
    # (NOTE(review): the names look like RWKV checkpoints -- confirm).
    # ``max`` makes sure a later key can only raise the detected version.
    version = 4
    for x in kk:
        if "ln_x" in x:
            version = max(5, version)
        if "gate.weight" in x:
            version = max(5.1, version)
        # A time_decay tensor with a second dimension larger than 1 bumps a
        # 5.x model to 5.2.
        if int(version) == 5 and "att.time_decay" in x:
            if len(loaded[x].shape) > 1:
                if loaded[x].shape[1] > 1:
                    version = max(5.2, version)
        if "time_maa" in x:
            version = max(6, version)

    print(f"Model detected: v{version:.1f}")

    if version == 5.1:
        # Expand the v5.1 decay/first tensors along a new second axis so that
        # rows * columns equals the embedding width.
        _, n_emb = loaded["emb.weight"].shape
        for k in kk:
            if "time_decay" in k or "time_faaaa" in k:
                loaded[k] = loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])

    # Collect the output in a separate dict: writing converted entries back
    # into ``loaded`` breaks when a renamed/lower-cased key equals a source
    # key that has not been processed yet.
    tensors = {}
    with torch.no_grad():
        for k in kk:
            new_k = rename_key(rename, k).lower()
            # pop() drops the source tensor right away to keep peak memory low.
            v = loaded.pop(k).half()
            for transpose_name in transpose_names:
                if transpose_name in new_k:
                    dims = len(v.shape)
                    v = v.transpose(dims - 2, dims - 1)
                    # Transpose at most once even if several names match.
                    break
            print(f"{new_k}\t{v.shape}\t{v.dtype}")
            tensors[new_k] = {
                "dtype": str(v.dtype).split(".")[-1],
                "shape": v.shape,
                "data": v.numpy().tobytes(),
            }

    # dirname() is "" for a bare filename and os.makedirs("") raises, so only
    # create the directory when the path actually has one.
    out_dir = os.path.dirname(sf_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    serialize_file(tensors, sf_filename, metadata={"format": "pt"})
    return sf_filename

def process_model(url):
    """Process the model URL and return a downloadable safetensors file.

    Downloads the checkpoint at *url* into a temporary ``.pth`` file, converts
    it with ``convert_file`` into a temporary ``.safetensors`` file and
    returns the path of the converted file. The temporary ``.pth`` file is
    always removed; the ``.safetensors`` file is removed only on failure.

    Args:
        url: Address of the ``.pth`` checkpoint to download.

    Returns:
        Path of the converted ``.safetensors`` file.

    Raises:
        gr.Error: If creating the temp files, downloading or converting
            fails; the original exception is attached as the cause.
    """
    # NOTE(security): the URL comes straight from the user and the downloaded
    # file is later unpickled by torch.load inside convert_file.

    # Bind the paths before the try block so the cleanup code below never
    # hits an unbound name when temp-file creation itself fails.
    pth_path = None
    sf_path = None
    try:
        # Create temporary files; delete=False keeps them on disk after the
        # handles are closed so they can be reopened by path.
        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as temp_pth:
            pth_path = temp_pth.name
        with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as temp_sf:
            sf_path = temp_sf.name

        # Download the .pth file from the URL
        download_file(url, pth_path)

        # Conversion parameters
        rename = {"time_faaaa": "time_first", "time_maa": "time_mix", "lora_A": "lora.0", "lora_B": "lora.1"}
        transpose_names = [
            "time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2",
            "w1", "w2", "a1", "a2", "g1", "g2", "v1", "v2",
            "time_state", "lora.0"
        ]

        # Convert the file
        return convert_file(pth_path, sf_path, rename, transpose_names)
    except Exception as e:
        # The partially written output is useless after a failure.
        if sf_path is not None and os.path.exists(sf_path):
            os.remove(sf_path)
        raise gr.Error(f"Error processing the model: {str(e)}") from e
    finally:
        # The downloaded checkpoint is never needed after this call,
        # whether the conversion succeeded or not.
        if pth_path is not None and os.path.exists(pth_path):
            os.remove(pth_path)

# Gradio UI: a URL textbox, a trigger button and a file component for the result.
with gr.Blocks(title="PTH to Safetensors Converter") as demo:
    # Page heading followed by a short usage hint.
    gr.Markdown(value="# PTH to Safetensors Converter")
    gr.Markdown(
        value="Enter the URL to a `.pth` model file hosted on Hugging Face to convert it to `.safetensors` format."
    )

    url_input = gr.Textbox(
        placeholder="https://huggingface.co/.../model.pth",
        label="Model URL",
    )
    convert_btn = gr.Button(value="Convert")
    output_file = gr.File(label="Download Converted Safetensors File")

    # A click passes the textbox value to process_model and routes the
    # returned path to the file component.
    convert_btn.click(process_model, inputs=url_input, outputs=output_file)

# Start the app as soon as this module is executed.
demo.launch()