admin commited on
Commit
a2ea81c
·
1 Parent(s): 6bbbe9a

fix examples

Browse files
Files changed (2) hide show
  1. app.py +2 -3
  2. utils.py +7 -3
app.py CHANGED
@@ -376,12 +376,11 @@ def infer(wav_path: str, log_name: str, folder_path=TEMP_DIR):
376
 
377
  if __name__ == "__main__":
378
  warnings.filterwarnings("ignore")
379
- models = get_modelist()
380
  examples = []
381
  example_wavs = find_files()
382
- model_num = len(models)
383
  for wav in example_wavs:
384
- examples.append([wav, models[random.randint(0, model_num - 1)]])
385
 
386
  with gr.Blocks() as demo:
387
  gr.Interface(
 
376
 
377
  if __name__ == "__main__":
378
  warnings.filterwarnings("ignore")
379
+ models = get_modelist(assign_model="regnet_y_32gf_cqt")
380
  examples = []
381
  example_wavs = find_files()
 
382
  for wav in example_wavs:
383
+ examples.append([wav, models[0]])
384
 
385
  with gr.Blocks() as demo:
386
  gr.Interface(
utils.py CHANGED
@@ -29,7 +29,7 @@ def find_files(folder_path=f"{MODEL_DIR}/examples", ext=".wav"):
29
  return wav_files
30
 
31
 
32
- def get_modelist(model_dir=MODEL_DIR):
33
  try:
34
  entries = os.listdir(model_dir)
35
  except OSError as e:
@@ -40,11 +40,15 @@ def get_modelist(model_dir=MODEL_DIR):
40
  for entry in entries:
41
  full_path = os.path.join(model_dir, entry)
42
  if entry == ".git" or entry == "examples":
43
- print(f"Skip .git or examples dir: {full_path}")
44
  continue
45
 
46
  if os.path.isdir(full_path):
47
- output.append(os.path.basename(full_path))
 
 
 
 
48
 
49
  return output
50
 
 
29
  return wav_files
30
 
31
 
32
+ def get_modelist(model_dir=MODEL_DIR, assign_model=""):
33
  try:
34
  entries = os.listdir(model_dir)
35
  except OSError as e:
 
40
  for entry in entries:
41
  full_path = os.path.join(model_dir, entry)
42
  if entry == ".git" or entry == "examples":
43
+ print(f"Skip .git / examples dir: {full_path}")
44
  continue
45
 
46
  if os.path.isdir(full_path):
47
+ model = os.path.basename(full_path)
48
+ if assign_model and assign_model.lower() in model:
49
+ output.insert(0, model)
50
+ else:
51
+ output.append(model)
52
 
53
  return output
54