ksang commited on
Commit
057d8eb
·
1 Parent(s): ad16c3d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -20,8 +20,8 @@ def load_pickle(file_path: str, mode: str = "rb", encoding=""):
20
  return pickle.load(f, encoding=encoding)
21
 
22
  # %%
23
- label2id = load_pickle('/data/audio-classification-pytorch/wav2vec2/results/best/label2id.pkl')
24
- id2label = load_pickle('/data/audio-classification-pytorch/wav2vec2/results/best/id2label.pkl')
25
 
26
  # %%
27
  model = AutoModelForAudioClassification.from_pretrained(
@@ -32,7 +32,7 @@ model = AutoModelForAudioClassification.from_pretrained(
32
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
33
 
34
  # %%
35
- checkpoint = torch.load('/data/audio-classification-pytorch/wav2vec2/results/best/pytorch_model.bin')
36
 
37
  # %%
38
  model.load_state_dict(checkpoint)
@@ -52,7 +52,6 @@ def predict(input):
52
  label_name = id2label[str(label_id)]
53
 
54
  return label_name
55
-
56
  # %%
57
  demo = gr.Interface(
58
  fn=predict,
 
20
  return pickle.load(f, encoding=encoding)
21
 
22
  # %%
23
+ label2id = load_pickle('label2id.pkl')
24
+ id2label = load_pickle('id2label.pkl')
25
 
26
  # %%
27
  model = AutoModelForAudioClassification.from_pretrained(
 
32
  feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
33
 
34
  # %%
35
+ checkpoint = torch.load('pytorch_model.bin')
36
 
37
  # %%
38
  model.load_state_dict(checkpoint)
 
52
  label_name = id2label[str(label_id)]
53
 
54
  return label_name
 
55
  # %%
56
  demo = gr.Interface(
57
  fn=predict,