jsd219 commited on
Commit
1f7d3e4
·
verified ·
1 Parent(s): 0f31a35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -42
app.py CHANGED
@@ -1,53 +1,22 @@
1
  import gradio as gr
2
- import torch
3
  from transformers import pipeline
 
4
 
5
  model_id = "ntu-spml/distilhubert"
6
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
7
  pipe = pipeline("audio-classification", model=model_id, device=device)
8
 
9
- # def predict_trunc(filepath):
10
- # preprocessed = pipe.preprocess(filepath)
11
- # truncated = pipe.feature_extractor.pad(preprocessed,truncation=True, max_length = 16_000*30)
12
- # model_outputs = pipe.forward(truncated)
13
- # outputs = pipe.postprocess(model_outputs)
14
-
15
- # return outputs
16
-
17
-
18
  def classify_audio(filepath):
19
  import time
20
- start_time = time.time()
21
-
22
- # Assuming `pipe` is your model pipeline for inference
23
  preds = pipe(filepath)
24
-
25
- outputs = {}
26
- for p in preds:
27
- outputs[p["label"]] = p["score"]
28
-
29
- end_time = time.time()
30
- prediction_time = end_time - start_time
31
-
32
- return outputs, prediction_time
33
-
34
 
35
- title = "🎵 Music Genre Classifier"
36
- description = """
37
- Music Genre Classifier model (Fine-tuned "ntu-spml/distilhubert") Dataset: [GTZAN](https://huggingface.co/datasets/marsyas/gtzan)
38
- """
39
-
40
- filenames = ['rock-it-21275.mp3']
41
- filenames = [f"./{f}" for f in filenames]
42
-
43
- demo = gr.Interface(
44
  fn=classify_audio,
45
- inputs=gr.Audio(type="filepath"),
46
- outputs=[gr.Label(), gr.Number(label="Prediction time (s)")],
47
- title=title,
48
- description=description,
49
- )
50
-
51
-
52
-
53
- demo.launch()
 
1
  import gradio as gr
 
2
  from transformers import pipeline
3
+ import torch
4
 
5
  model_id = "ntu-spml/distilhubert"
6
+ device = 0 if torch.cuda.is_available() else -1
7
  pipe = pipeline("audio-classification", model=model_id, device=device)
8
 
 
 
 
 
 
 
 
 
 
9
  def classify_audio(filepath):
10
  import time
11
+ start = time.time()
 
 
12
  preds = pipe(filepath)
13
+ result = {p["label"]: round(p["score"], 3) for p in preds}
14
+ return result, round(time.time() - start, 2)
 
 
 
 
 
 
 
 
15
 
16
+ gr.Interface(
 
 
 
 
 
 
 
 
17
  fn=classify_audio,
18
+ inputs=gr.Audio(type="filepath", label="Upload Audio"),
19
+ outputs=[gr.Label(label="Top Genres"), gr.Number(label="Time (s)")],
20
+ title="🎵 Music Genre Classifier",
21
+ description="Classifies the genre of uploaded audio using DistilHuBERT fine-tuned on GTZAN."
22
+ ).launch()