mizoru commited on
Commit
d78118d
·
1 Parent(s): e8af28e

add version choice

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -11,11 +11,19 @@ def get_x(df):
11
  def get_y(df):
12
  return df.pattern
13
 
14
- learn = load_learner('xresnet50_pitch3_removeSilence.pkl')
15
 
16
- labels = learn.dls.vocab
17
 
18
- def predict(Record, Upload):
 
 
 
 
 
 
 
 
19
  if Upload: path = Upload
20
  else: path = Record
21
  spec,pred,pred_idx,probs = learn.predict(str(path), with_input=True)
@@ -35,5 +43,10 @@ examples = [['代わる.mp3'],['大丈夫な.mp3'],['熱くない.mp3'], ['あ
35
 
36
  enable_queue=True
37
 
38
- gr.Interface(fn=predict,inputs=[gr.inputs.Audio(source='microphone', type='filepath', optional=True), gr.inputs.Audio(source='upload', type='filepath', optional=True)], outputs= [gr.outputs.Label(num_top_classes=3), gr.outputs.Image(type="plot", label='Spectrogram')], title=title,description=description,article=article,examples=examples).launch(debug=True, enable_queue=enable_queue)
 
 
 
 
 
39
 
 
11
  def get_y(df):
12
  return df.pattern
13
 
14
+ learn_removeSilence = load_learner('xresnet50_pitch3_removeSilence.pkl')
15
 
16
+ learn_plain = load_learner('xresnet50_pitch3.pkl')
17
 
18
+ labels = learn_removeSilence.dls.vocab
19
+
20
+ def process(Record, Upload, version):
21
+ if version == 'remove silence':
22
+ return predict(Record, Upload, learn_removeSilence)
23
+ elif version == 'plain':
24
+ return predict(Record, Upload, learn_plain)
25
+
26
+ def predict(Record, Upload, learn):
27
  if Upload: path = Upload
28
  else: path = Record
29
  spec,pred,pred_idx,probs = learn.predict(str(path), with_input=True)
 
43
 
44
  enable_queue=True
45
 
46
+ gr.Interface(fn=predict,
47
+ inputs=[gr.inputs.Audio(source='microphone', type='filepath', optional=True),
48
+ gr.inputs.Audio(source='upload', type='filepath', optional=True),
49
+ gr.inputs.Radio(choices=['plain','remove silence'], type="value", default='remove silence', label='version')
50
+ ],
51
+ outputs= [gr.outputs.Label(num_top_classes=3), gr.outputs.Image(type="plot", label='Spectrogram')], title=title,description=description,article=article,examples=examples).launch(debug=True, enable_queue=enable_queue)
52