Spaces:
Runtime error
Runtime error
patrickvonplaten
committed on
Commit
·
c6c5536
1
Parent(s):
38707b6
Update convert.py
Browse files - convert.py +1 -53
convert.py
CHANGED
@@ -133,57 +133,6 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
|
|
133 |
errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
|
134 |
return "\n".join(errors)
|
135 |
|
136 |
-
|
137 |
-
def check_final_model(model_id: str, folder: str):
    """Check that the model loaded from `folder` matches the one loaded from `model_id`.

    The hub `config.json` of `model_id` is copied into `folder`, then the model
    is loaded twice through `infer_framework_load_model`: once from `model_id`
    and once from `folder`. The two loads must agree on their loading infos,
    on which parameters share storage, and on the first output of a forward
    pass over dummy inputs.

    Args:
        model_id: Hub repository id of the reference model.
        folder: Local directory to load the second copy of the model from
            (presumably the converted safetensors files; confirm against caller).

    Raises:
        ValueError: If the loading infos of the two models differ.
        RuntimeError: If the shared-tensor layout of the two models differs.
        AssertionError: If the first outputs of the two models are not close
            (raised by `torch.testing.assert_close`).
    """
    # Copy the hub's config.json into `folder` so the config can be loaded from there.
    config = hf_hub_download(repo_id=model_id, filename="config.json")
    shutil.copy(config, os.path.join(folder, "config.json"))
    config = AutoConfig.from_pretrained(folder)

    _, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
    _, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)

    if pt_infos != sf_infos:
        error_string = create_diff(pt_infos, sf_infos)
        raise ValueError(f"Different infos when reloading the model: {error_string}")

    pt_params = pt_model.state_dict()
    sf_params = sf_model.state_dict()

    pt_shared = shared_pointers(pt_params)
    sf_shared = shared_pointers(sf_params)
    if pt_shared != sf_shared:
        # Fixed: the message was a plain string referencing non-existent names
        # (`shared_pt` / `shared_tf`), so the differing values were never shown.
        raise RuntimeError(
            f"The reconstructed model is wrong, shared tensors are different {pt_shared} != {sf_shared}"
        )

    # Build dummy inputs only for the arguments the model's forward() accepts.
    sig = signature(pt_model.forward)
    input_ids = torch.arange(10).unsqueeze(0)
    pixel_values = torch.randn(1, 3, 224, 224)
    input_values = torch.arange(1000).float().unsqueeze(0)
    kwargs = {}
    if "input_ids" in sig.parameters:
        kwargs["input_ids"] = input_ids
    if "decoder_input_ids" in sig.parameters:
        kwargs["decoder_input_ids"] = input_ids
    if "pixel_values" in sig.parameters:
        kwargs["pixel_values"] = pixel_values
    if "input_values" in sig.parameters:
        kwargs["input_values"] = input_values
    if "bbox" in sig.parameters:
        kwargs["bbox"] = torch.zeros((1, 10, 4)).long()
    if "image" in sig.parameters:
        kwargs["image"] = pixel_values

    # Run on GPU when one is available; models and inputs must be on the same device.
    if torch.cuda.is_available():
        pt_model = pt_model.cuda()
        sf_model = sf_model.cuda()
        kwargs = {k: v.cuda() for k, v in kwargs.items()}

    # Both models receive the same inputs; only the first output is compared.
    pt_logits = pt_model(**kwargs)[0]
    sf_logits = sf_model(**kwargs)[0]

    torch.testing.assert_close(sf_logits, pt_logits)
    print(f"Model {model_id} is ok !")
|
185 |
-
|
186 |
-
|
187 |
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
|
188 |
try:
|
189 |
discussions = api.get_repo_discussions(repo_id=model_id)
|
@@ -218,7 +167,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
|
|
218 |
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
|
219 |
pr_title = "Adding `safetensors` variant of this model"
|
220 |
info = api.model_info(model_id)
|
221 |
-
filenames = set(s.rfilename for s in info.siblings)
|
222 |
|
223 |
with TemporaryDirectory() as d:
|
224 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
@@ -242,7 +191,6 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
|
|
242 |
operations = convert_multi(model_id, folder)
|
243 |
else:
|
244 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
245 |
-
check_final_model(model_id, folder)
|
246 |
else:
|
247 |
operations = convert_generic(model_id, folder, filenames)
|
248 |
|
|
|
133 |
errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
|
134 |
return "\n".join(errors)
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
|
137 |
try:
|
138 |
discussions = api.get_repo_discussions(repo_id=model_id)
|
|
|
167 |
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
|
168 |
pr_title = "Adding `safetensors` variant of this model"
|
169 |
info = api.model_info(model_id)
|
170 |
+
filenames = set(s.rfilename for s in info.siblings if len(s.rfilename.split("/")) > 1)
|
171 |
|
172 |
with TemporaryDirectory() as d:
|
173 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
|
|
191 |
operations = convert_multi(model_id, folder)
|
192 |
else:
|
193 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
|
|
194 |
else:
|
195 |
operations = convert_generic(model_id, folder, filenames)
|
196 |
|