Bradarr committed
Commit d4365d2 · verified · 1 Parent(s): cdaeea0

Create app.py

Files changed (1)
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
+ import gradio as gr
+ import collections
+ import numpy as np
+ import os
+ import torch
+ from safetensors.torch import serialize_file
+ import requests
+ import tempfile
+
+ def download_file(url, local_path):
+     """Download a file from a URL to a local path."""
+     response = requests.get(url, stream=True)
+     response.raise_for_status()
+     with open(local_path, 'wb') as f:
+         for chunk in response.iter_content(chunk_size=8192):
+             f.write(chunk)
+     return local_path
+
+ def rename_key(rename, name):
+     """Apply the key-substring renames defined in `rename` to a parameter name."""
+     for k, v in rename.items():
+         if k in name:
+             name = name.replace(k, v)
+     return name
+
+ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
+     """Convert a PyTorch .pth checkpoint to a .safetensors file, applying key renames and optional transposes."""
+     loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
+     if "state_dict" in loaded:
+         loaded = loaded["state_dict"]
+
+     # Detect the RWKV architecture version from characteristic parameter names
+     kk = list(loaded.keys())
+     version = 4
+     for x in kk:
+         if "ln_x" in x:
+             version = max(5, version)
+         if "gate.weight" in x:
+             version = max(5.1, version)
+         if int(version) == 5 and "att.time_decay" in x:
+             if len(loaded[x].shape) > 1:
+                 if loaded[x].shape[1] > 1:
+                     version = max(5.2, version)
+         if "time_maa" in x:
+             version = max(6, version)
+
+     print(f"Model detected: v{version:.1f}")
+
+     if version == 5.1:
+         # v5.1 stores time_decay/time_faaaa as 1-D per-head vectors; expand them to (n_head, head_size)
+         _, n_emb = loaded["emb.weight"].shape
+         for k in kk:
+             if "time_decay" in k or "time_faaaa" in k:
+                 loaded[k] = loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
+
+     with torch.no_grad():
+         for k in kk:
+             new_k = rename_key(rename, k).lower()
+             v = loaded[k].half()
+             del loaded[k]
+             for transpose_name in transpose_names:
+                 if transpose_name in new_k:
+                     dims = len(v.shape)
+                     v = v.transpose(dims - 2, dims - 1)
+                     break
+             print(f"{new_k}\t{v.shape}\t{v.dtype}")
+             # Build the raw {dtype, shape, data} entries that serialize_file expects
+             loaded[new_k] = {
+                 "dtype": str(v.dtype).split(".")[-1],
+                 "shape": v.shape,
+                 "data": v.numpy().tobytes(),
+             }
+
+     os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
+     serialize_file(loaded, sf_filename, metadata={"format": "pt"})
+     return sf_filename
+
+ def process_model(url):
+     """Process the model URL and return a downloadable safetensors file."""
+     try:
+         # Create temporary files
+         with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as temp_pth:
+             pth_path = temp_pth.name
+         with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as temp_sf:
+             sf_path = temp_sf.name
+
+         # Download the .pth file from the URL
+         download_file(url, pth_path)
+
+         # Conversion parameters
+         rename = {"time_faaaa": "time_first", "time_maa": "time_mix", "lora_A": "lora.0", "lora_B": "lora.1"}
+         transpose_names = [
+             "time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2",
+             "w1", "w2", "a1", "a2", "g1", "g2", "v1", "v2",
+             "time_state", "lora.0"
+         ]
+
+         # Convert the file
+         converted_file = convert_file(pth_path, sf_path, rename, transpose_names)
+
+         # Clean up the temporary .pth file
+         os.remove(pth_path)
+
+         return converted_file
+     except Exception as e:
+         # Clean up temporary files in case of error (skip paths that were never assigned)
+         if "pth_path" in locals() and os.path.exists(pth_path):
+             os.remove(pth_path)
+         if "sf_path" in locals() and os.path.exists(sf_path):
+             os.remove(sf_path)
+         raise gr.Error(f"Error processing the model: {str(e)}")
+
+ # Gradio interface
+ with gr.Blocks(title="PTH to Safetensors Converter") as demo:
+     gr.Markdown("# PTH to Safetensors Converter")
+     gr.Markdown("Enter the URL to a `.pth` model file hosted on Hugging Face to convert it to `.safetensors` format.")
+
+     url_input = gr.Textbox(label="Model URL", placeholder="https://huggingface.co/.../model.pth")
+     convert_btn = gr.Button("Convert")
+     output_file = gr.File(label="Download Converted Safetensors File")
+
+     convert_btn.click(
+         fn=process_model,
+         inputs=url_input,
+         outputs=output_file
+     )
+
+ demo.launch()
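
Usage note (not part of the commit): the conversion core can also be driven without the web UI by calling convert_file directly with the same rename map and transpose list that process_model builds. The sketch below is illustrative only; "model.pth" and "out/model.safetensors" are placeholder paths, and it assumes convert_file and rename_key are in scope (for example, pasted into the same script). Importing app.py itself would additionally start the Gradio server, since demo.launch() runs at module level.

# Illustrative standalone driver; paths are placeholders and the parameters mirror process_model above.
rename = {"time_faaaa": "time_first", "time_maa": "time_mix", "lora_A": "lora.0", "lora_B": "lora.1"}
transpose_names = [
    "time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2",
    "w1", "w2", "a1", "a2", "g1", "g2", "v1", "v2",
    "time_state", "lora.0",
]
converted = convert_file("model.pth", "out/model.safetensors", rename, transpose_names)
print("Wrote", converted)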