Gregniuki committed on
Commit 175e3f9
1 Parent(s): 4058a36

Delete model/utils.py

Files changed (1)
  1. model/utils.py +0 -580
model/utils.py DELETED
@@ -1,580 +0,0 @@
-from __future__ import annotations
-
-import os
-import re
-import math
-import random
-import string
-from tqdm import tqdm
-from collections import defaultdict
-
-import matplotlib
-matplotlib.use("Agg")
-import matplotlib.pylab as plt
-
-import torch
-import torch.nn.functional as F
-from torch.nn.utils.rnn import pad_sequence
-import torchaudio
-
-import einx
-from einops import rearrange, reduce
-
-import jieba
-from pypinyin import lazy_pinyin, Style
-
-from model.ecapa_tdnn import ECAPA_TDNN_SMALL
-from model.modules import MelSpec
-
-
-# seed everything
-
-def seed_everything(seed = 0):
-    random.seed(seed)
-    os.environ['PYTHONHASHSEED'] = str(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-
-# helpers
-
-def exists(v):
-    return v is not None
-
-def default(v, d):
-    return v if exists(v) else d
-
-# tensor helpers
-
-def lens_to_mask(
-    t: int['b'],
-    length: int | None = None
-) -> bool['b n']:
-
-    if not exists(length):
-        length = t.amax()
-
-    seq = torch.arange(length, device = t.device)
-    return einx.less('n, b -> b n', seq, t)
-
-def mask_from_start_end_indices(
-    seq_len: int['b'],
-    start: int['b'],
-    end: int['b']
-):
-    max_seq_len = seq_len.max().item()
-    seq = torch.arange(max_seq_len, device = start.device).long()
-    return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
-
-def mask_from_frac_lengths(
-    seq_len: int['b'],
-    frac_lengths: float['b']
-):
-    lengths = (frac_lengths * seq_len).long()
-    max_start = seq_len - lengths
-
-    rand = torch.rand_like(frac_lengths)
-    start = (max_start * rand).long().clamp(min = 0)
-    end = start + lengths
-
-    return mask_from_start_end_indices(seq_len, start, end)
-
-def maybe_masked_mean(
-    t: float['b n d'],
-    mask: bool['b n'] = None
-) -> float['b d']:
-
-    if not exists(mask):
-        return t.mean(dim = 1)
-
-    t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
-    num = reduce(t, 'b n d -> b d', 'sum')
-    den = reduce(mask.float(), 'b n -> b', 'sum')
-
-    return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
-
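The example below is not part of the deleted file; it is a minimal sketch of what the mask helpers above produce, assuming `torch` and `einx` are installed and `model/utils.py` is importable from this repository layout.

```python
# Illustrative sketch only (not from the commit): exercising the mask helpers above.
import torch
from model.utils import lens_to_mask, mask_from_frac_lengths  # assumes this repo is on the path

lens = torch.tensor([2, 4])
print(lens_to_mask(lens))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])

# Random sub-spans covering half of each sequence, as used for infilling masks.
print(mask_from_frac_lengths(lens, torch.tensor([0.5, 0.5])).shape)  # torch.Size([2, 4])
```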
-
-# simple utf-8 tokenizer, since paper went character based
-def list_str_to_tensor(
-    text: list[str],
-    padding_value = -1
-) -> int['b nt']:
-    list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
-    text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
-    return text
-
-# char tokenizer, based on custom dataset's extracted .txt file
-def list_str_to_idx(
-    text: list[str] | list[list[str]],
-    vocab_char_map: dict[str, int], # {char: idx}
-    padding_value = -1
-) -> int['b nt']:
-    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
-    text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
-    return text
-
-
-# Get tokenizer
-
-def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
-    '''
-    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
-                - "char" for char-wise tokenizer, need .txt vocab_file
-                - "byte" for utf-8 tokenizer
-                - "custom" if you're directly passing in a path to the vocab.txt you want to use
-    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
-                - if use "char", derived from unfiltered character & symbol counts of custom dataset
-                - if use "byte", set to 256 (unicode byte range)
-    '''
-    if tokenizer in ["pinyin", "char"]:
-        with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
-            vocab_char_map = {}
-            for i, char in enumerate(f):
-                vocab_char_map[char[:-1]] = i
-        vocab_size = len(vocab_char_map)
-        assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
-
-    elif tokenizer == "byte":
-        vocab_char_map = None
-        vocab_size = 256
-    elif tokenizer == "custom":
-        with open (dataset_name, "r", encoding="utf-8") as f:
-            vocab_char_map = {}
-            for i, char in enumerate(f):
-                vocab_char_map[char[:-1]] = i
-        vocab_size = len(vocab_char_map)
-
-    return vocab_char_map, vocab_size
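Not part of the deleted file: a minimal sketch of how the vocab loading described in the docstring above might be used, assuming a hypothetical `data/my_dataset_char/vocab.txt` whose first line is a single space.

```python
# Illustrative sketch only (not from the commit); dataset name and vocab path are hypothetical.
from model.utils import get_tokenizer, list_str_to_idx

vocab_char_map, vocab_size = get_tokenizer("my_dataset", tokenizer="char")
batch = list_str_to_idx(["hello world"], vocab_char_map)  # (b, nt) int tensor, unknown chars map to 0
print(vocab_size, batch.shape)
```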
-
-
-# convert char to pinyin
-
-def convert_char_to_pinyin(text_list, polyphone = True):
-    final_text_list = []
-    god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
-    custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
-    for text in text_list:
-        char_list = []
-        text = text.translate(god_knows_why_en_testset_contains_zh_quote)
-        text = text.translate(custom_trans)
-        for seg in jieba.cut(text):
-            seg_byte_len = len(bytes(seg, 'UTF-8'))
-            if seg_byte_len == len(seg): # if pure alphabets and symbols
-                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
-                    char_list.append(" ")
-                char_list.extend(seg)
-            elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
-                seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
-                for c in seg:
-                    if c not in "。，、；：？！《》【】—…":
-                        char_list.append(" ")
-                    char_list.append(c)
-            else: # if mixed chinese characters, alphabets and symbols
-                for c in seg:
-                    if ord(c) < 256:
-                        char_list.extend(c)
-                    else:
-                        if c not in "。，、；：？！《》【】—…":
-                            char_list.append(" ")
-                            char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
-                        else: # if is zh punc
-                            char_list.append(c)
-        final_text_list.append(char_list)
-
-    return final_text_list
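Not part of the deleted file: a small sketch of the mixed Chinese/English conversion performed above, assuming `jieba` and `pypinyin` are installed; the input strings are hypothetical.

```python
# Illustrative sketch only (not from the commit).
from model.utils import convert_char_to_pinyin

out = convert_char_to_pinyin(["hello world", "你好，世界"], polyphone=True)
# English segments stay character-based; Chinese characters become space-separated
# TONE3 pinyin tokens (e.g. "ni3", "hao3"), with Chinese pause punctuation passed through.
print(out)
```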
-
-
-# save spectrogram
-def save_spectrogram(spectrogram, path):
-    plt.figure(figsize=(12, 4))
-    plt.imshow(spectrogram, origin='lower', aspect='auto')
-    plt.colorbar()
-    plt.savefig(path)
-    plt.close()
-
-
-# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
-def get_seedtts_testset_metainfo(metalst):
-    f = open(metalst); lines = f.readlines(); f.close()
-    metainfo = []
-    for line in lines:
-        if len(line.strip().split('|')) == 5:
-            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
-        elif len(line.strip().split('|')) == 4:
-            utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
-            gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
-        if not os.path.isabs(prompt_wav):
-            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
-        metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
-    return metainfo
-
-
-# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
-def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
-    f = open(metalst); lines = f.readlines(); f.close()
-    metainfo = []
-    for line in lines:
-        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
-
-        # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
-        ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
-        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
-
-        # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
-        gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
-        gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
-
-        metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
-
-    return metainfo
-
-
-# padded to max length mel batch
-def padded_mel_batch(ref_mels):
-    max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
-    padded_ref_mels = []
-    for mel in ref_mels:
-        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
-        padded_ref_mels.append(padded_ref_mel)
-    padded_ref_mels = torch.stack(padded_ref_mels)
-    padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
-    return padded_ref_mels
-
-
-# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
-
-def get_inference_prompt(
-    metainfo,
-    speed = 1., tokenizer = "pinyin", polyphone = True,
-    target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
-    use_truth_duration = False,
-    infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
-):
-    prompts_all = []
-
-    min_tokens = min_secs * target_sample_rate // hop_length
-    max_tokens = max_secs * target_sample_rate // hop_length
-
-    batch_accum = [0] * num_buckets
-    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
-        ([[] for _ in range(num_buckets)] for _ in range(6))
-
-    mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
-
-    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
-
-        # Audio
-        ref_audio, ref_sr = torchaudio.load(prompt_wav)
-        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
-        if ref_rms < target_rms:
-            ref_audio = ref_audio * target_rms / ref_rms
-        assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
-        if ref_sr != target_sample_rate:
-            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
-            ref_audio = resampler(ref_audio)
-
-        # Text
-        if len(prompt_text[-1].encode('utf-8')) == 1:
-            prompt_text = prompt_text + " "
-        text = [prompt_text + gt_text]
-        if tokenizer == "pinyin":
-            text_list = convert_char_to_pinyin(text, polyphone = polyphone)
-        else:
-            text_list = text
-
-        # Duration, mel frame length
-        ref_mel_len = ref_audio.shape[-1] // hop_length
-        if use_truth_duration:
-            gt_audio, gt_sr = torchaudio.load(gt_wav)
-            if gt_sr != target_sample_rate:
-                resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
-                gt_audio = resampler(gt_audio)
-            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
-
-            # # test vocoder resynthesis
-            # ref_audio = gt_audio
-        else:
-            zh_pause_punc = r"。，、；：？！"
-            ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
-            gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
-            total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
-
-        # to mel spectrogram
-        ref_mel = mel_spectrogram(ref_audio)
-        ref_mel = rearrange(ref_mel, '1 d n -> d n')
-
-        # deal with batch
-        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
-        assert min_tokens <= total_mel_len <= max_tokens, \
-            f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
-        bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
-
-        utts[bucket_i].append(utt)
-        ref_rms_list[bucket_i].append(ref_rms)
-        ref_mels[bucket_i].append(ref_mel)
-        ref_mel_lens[bucket_i].append(ref_mel_len)
-        total_mel_lens[bucket_i].append(total_mel_len)
-        final_text_list[bucket_i].extend(text_list)
-
-        batch_accum[bucket_i] += total_mel_len
-
-        if batch_accum[bucket_i] >= infer_batch_size:
-            # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
-            prompts_all.append((
-                utts[bucket_i],
-                ref_rms_list[bucket_i],
-                padded_mel_batch(ref_mels[bucket_i]),
-                ref_mel_lens[bucket_i],
-                total_mel_lens[bucket_i],
-                final_text_list[bucket_i]
-            ))
-            batch_accum[bucket_i] = 0
-            utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
-
-    # add residual
-    for bucket_i, bucket_frames in enumerate(batch_accum):
-        if bucket_frames > 0:
-            prompts_all.append((
-                utts[bucket_i],
-                ref_rms_list[bucket_i],
-                padded_mel_batch(ref_mels[bucket_i]),
-                ref_mel_lens[bucket_i],
-                total_mel_lens[bucket_i],
-                final_text_list[bucket_i]
-            ))
-    # not only leave easy work for last workers
-    random.seed(666)
-    random.shuffle(prompts_all)
-
-    return prompts_all
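Not part of the deleted file: a worked, hypothetical instance of the duration heuristic used above when `use_truth_duration` is False (reference mel length scaled by the ratio of UTF-8 text lengths, with each Chinese pause punctuation mark counted as three extra bytes).

```python
# Illustrative sketch only (not from the commit); all numbers are hypothetical.
hop_length, target_sample_rate, speed = 256, 24000, 1.0

ref_mel_len = 3 * target_sample_rate // hop_length            # 3 s prompt -> 281 frames
ref_text_len = len("some prompt text".encode("utf-8"))        # 16 bytes
gen_text_len = len("text to be synthesized".encode("utf-8"))  # 22 bytes

total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
print(total_mel_len)  # 281 + 386 = 667 mel frames allocated for prompt + generation
```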
-
-
-# get wav_res_ref_text of seed-tts test metalst
-# https://github.com/BytedanceSpeech/seed-tts-eval
-
-def get_seed_tts_test(metalst, gen_wav_dir, gpus):
-    f = open(metalst)
-    lines = f.readlines()
-    f.close()
-
-    test_set_ = []
-    for line in tqdm(lines):
-        if len(line.strip().split('|')) == 5:
-            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
-        elif len(line.strip().split('|')) == 4:
-            utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
-
-        if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
-            continue
-        gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
-        if not os.path.isabs(prompt_wav):
-            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
-
-        test_set_.append((gen_wav, prompt_wav, gt_text))
-
-    num_jobs = len(gpus)
-    if num_jobs == 1:
-        return [(gpus[0], test_set_)]
-
-    wav_per_job = len(test_set_) // num_jobs + 1
-    test_set = []
-    for i in range(num_jobs):
-        test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
-
-    return test_set
-
-
-# get librispeech test-clean cross sentence test
-
-def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
-    f = open(metalst)
-    lines = f.readlines()
-    f.close()
-
-    test_set_ = []
-    for line in tqdm(lines):
-        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
-
-        if eval_ground_truth:
-            gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
-            gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
-        else:
-            if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
-                raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
-            gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
-
-        ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
-        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
-
-        test_set_.append((gen_wav, ref_wav, gen_txt))
-
-    num_jobs = len(gpus)
-    if num_jobs == 1:
-        return [(gpus[0], test_set_)]
-
-    wav_per_job = len(test_set_) // num_jobs + 1
-    test_set = []
-    for i in range(num_jobs):
-        test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
-
-    return test_set
-
-
-# load asr model
-
-def load_asr_model(lang, ckpt_dir = ""):
-    if lang == "zh":
-        from funasr import AutoModel
-        model = AutoModel(
-            model = os.path.join(ckpt_dir, "paraformer-zh"),
-            # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
-            # punc_model = os.path.join(ckpt_dir, "ct-punc"),
-            # spk_model = os.path.join(ckpt_dir, "cam++"),
-            disable_update=True,
-        ) # following seed-tts setting
-    elif lang == "en":
-        from faster_whisper import WhisperModel
-        model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
-        model = WhisperModel(model_size, device="cpu", compute_type="float16")
-    return model
-
-
-# WER Evaluation, the way Seed-TTS does
-
-def run_asr_wer(args):
-    rank, lang, test_set, ckpt_dir = args
-
-    if lang == "zh":
-        import zhconv
-        torch.cuda.set_device(rank)
-    elif lang == "en":
-        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
-    else:
-        raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
-
-    asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
-
-    from zhon.hanzi import punctuation
-    punctuation_all = punctuation + string.punctuation
-    wers = []
-
-    from jiwer import compute_measures
-    for gen_wav, prompt_wav, truth in tqdm(test_set):
-        if lang == "zh":
-            res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
-            hypo = res[0]["text"]
-            hypo = zhconv.convert(hypo, 'zh-cn')
-        elif lang == "en":
-            segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
-            hypo = ''
-            for segment in segments:
-                hypo = hypo + ' ' + segment.text
-
-        # raw_truth = truth
-        # raw_hypo = hypo
-
-        for x in punctuation_all:
-            truth = truth.replace(x, '')
-            hypo = hypo.replace(x, '')
-
-        truth = truth.replace('  ', ' ')
-        hypo = hypo.replace('  ', ' ')
-
-        if lang == "zh":
-            truth = " ".join([x for x in truth])
-            hypo = " ".join([x for x in hypo])
-        elif lang == "en":
-            truth = truth.lower()
-            hypo = hypo.lower()
-
-        measures = compute_measures(truth, hypo)
-        wer = measures["wer"]
-
-        # ref_list = truth.split(" ")
-        # subs = measures["substitutions"] / len(ref_list)
-        # dele = measures["deletions"] / len(ref_list)
-        # inse = measures["insertions"] / len(ref_list)
-
-        wers.append(wer)
-
-    return wers
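Not part of the deleted file: a sketch of how the per-GPU argument tuple consumed by `run_asr_wer` above might be assembled from a test set built by `get_seed_tts_test`; the metalst and output paths are hypothetical, and this assumes an English evaluation with faster-whisper installed.

```python
# Illustrative sketch only (not from the commit); paths are hypothetical.
from model.utils import get_seed_tts_test, run_asr_wer

gpus = [0]
test_sets = get_seed_tts_test("seedtts_testset/en/meta.lst", "results/gen_wavs", gpus)
rank, subset = test_sets[0]
wers = run_asr_wer((rank, "en", subset, ""))  # args tuple: (rank, lang, test_set, ckpt_dir)
print(sum(wers) / max(len(wers), 1))
```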
-
-
-# SIM Evaluation
-
-def run_sim(args):
-    rank, test_set, ckpt_dir = args
-    device = f"cuda:{rank}"
-
-    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
-    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
-    model.load_state_dict(state_dict['model'], strict=False)
-
-    use_gpu=True if torch.cuda.is_available() else False
-    if use_gpu:
-        model = model.cuda(device)
-    model.eval()
-
-    sim_list = []
-    for wav1, wav2, truth in tqdm(test_set):
-
-        wav1, sr1 = torchaudio.load(wav1)
-        wav2, sr2 = torchaudio.load(wav2)
-
-        resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
-        resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
-        wav1 = resample1(wav1)
-        wav2 = resample2(wav2)
-
-        if use_gpu:
-            wav1 = wav1.cuda(device)
-            wav2 = wav2.cuda(device)
-        with torch.no_grad():
-            emb1 = model(wav1)
-            emb2 = model(wav2)
-
-        sim = F.cosine_similarity(emb1, emb2)[0].item()
-        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
-        sim_list.append(sim)
-
-    return sim_list
-
-
-# filter func for dirty data with many repetitions
-
-def repetition_found(text, length = 2, tolerance = 10):
-    pattern_count = defaultdict(int)
-    for i in range(len(text) - length + 1):
-        pattern = text[i:i + length]
-        pattern_count[pattern] += 1
-    for pattern, count in pattern_count.items():
-        if count > tolerance:
-            return True
-    return False
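Not part of the deleted file: two quick, hypothetical checks of the repetition filter above.

```python
# Illustrative sketch only (not from the commit).
from model.utils import repetition_found

print(repetition_found("ha" * 20))            # True: the 2-gram "ha" occurs 20 times (> tolerance of 10)
print(repetition_found("a normal sentence"))  # False: no 2-gram occurs more than 10 times
```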
-
-
-# load model checkpoint for inference
-
-def load_checkpoint(model, ckpt_path, device, use_ema = True):
-    from ema_pytorch import EMA
-
-    ckpt_type = ckpt_path.split(".")[-1]
-    if ckpt_type == "safetensors":
-        from safetensors.torch import load_file
-        checkpoint = load_file(ckpt_path, device=device)
-    else:
-        checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
-
-    if use_ema == True:
-        ema_model = EMA(model, include_online_model = False).to(device)
-        if ckpt_type == "safetensors":
-            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-        else:
-            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-        ema_model.copy_params_from_ema_to_model()
-    else:
-        model.load_state_dict(checkpoint['model_state_dict'])
-
-    return model
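Not part of the deleted file: a minimal sketch of restoring EMA weights for inference with the helper above; the checkpoint path is hypothetical and `model` is assumed to be an already-constructed instance of the TTS model.

```python
# Illustrative sketch only (not from the commit); the checkpoint path is hypothetical
# and `model` is assumed to exist already.
from model.utils import load_checkpoint

model = load_checkpoint(model, "ckpts/model_last.safetensors", device="cuda", use_ema=True)
model.eval()
```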