Niral Patel commited on
Commit
6e1a9b7
·
1 Parent(s): fee470d

change in model config

Browse files
Files changed (2) hide show
  1. custom_model.py +3 -1
  2. test.py +15 -15
custom_model.py CHANGED
@@ -23,10 +23,12 @@ class SpleeterModel(PreTrainedModel):
23
  Args:
24
  audio_path (str): Path to the input audio file.
25
  Returns:
26
- dict: Separated stems.
27
  """
28
  return self.separator.separate_to_file(audio_path, "separated_audio")
29
 
30
 
31
  AutoConfig.register("spleeter", SpleeterConfig)
32
  AutoModel.register(SpleeterConfig, SpleeterModel)
 
 
 
23
  Args:
24
  audio_path (str): Path to the input audio file.
25
  Returns:
26
+ path: Separated stems.
27
  """
28
  return self.separator.separate_to_file(audio_path, "separated_audio")
29
 
30
 
31
  AutoConfig.register("spleeter", SpleeterConfig)
32
  AutoModel.register(SpleeterConfig, SpleeterModel)
33
+ SpleeterConfig.register_for_auto_class()
34
+ SpleeterModel.register_for_auto_class("AutoModel")
test.py CHANGED
@@ -1,22 +1,22 @@
1
- from transformers import AutoConfig, AutoModel
2
- from custom_model import SpleeterModel
3
 
4
- config = AutoConfig.from_pretrained("niral-env/youtube_spleeter")
5
- print("----"*30)
6
- print(config)
7
- print("----"*30)
8
 
9
- model = SpleeterModel(config)
10
 
11
- print(model)
12
- result = model.forward("vocals.wav")
13
 
14
- print(result)
15
 
16
 
17
- # from transformers import AutoModel
18
- # model = AutoModel.from_pretrained("niral-env/youtube_spleeter")
19
- # print(model)
20
- # result = model.forward("vocals.wav")
21
 
22
- # print(result)
 
1
+ # from transformers import AutoConfig, AutoModel
2
+ # from custom_model import SpleeterModel
3
 
4
+ # config = AutoConfig.from_pretrained("niral-env/youtube_spleeter")
5
+ # print("----"*30)
6
+ # print(config)
7
+ # print("----"*30)
8
 
9
+ # model = SpleeterModel(config)
10
 
11
+ # print(model)
12
+ # result = model.forward("vocals.wav")
13
 
14
+ # print(result)
15
 
16
 
17
+ from transformers import AutoModel
18
+ model = AutoModel.from_pretrained("niral-env/youtube_spleeter", trust_remote_code=True)
19
+ print(model)
20
+ result = model.forward("vocals.wav")
21
 
22
+ print(result)