SeyedAli commited on
Commit
76fe6b5
·
1 Parent(s): 1be8004

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -10
app.py CHANGED
@@ -6,17 +6,12 @@ import torchaudio
6
  import gradio as gr
7
  from transformers import Wav2Vec2FeatureExtractor,AutoConfig
8
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
- from transformers.models.wav2vec2.modeling_wav2vec2 import (
10
- Wav2Vec2PreTrainedModel,
11
- Wav2Vec2Model
12
- )
13
- from transformers.models.hubert.modeling_hubert import (
14
- HubertPreTrainedModel,
15
- HubertModel
16
- )
17
 
18
  config = AutoConfig.from_pretrained("SeyedAli/Persian-Speech-Emotion-HuBert-V1")
19
- model = Wav2Vec2FeatureExtractor.from_pretrained("SeyedAli/Persian-Speech-Emotion-HuBert-V1")
 
 
20
 
21
  audio_input = gr.Audio(label="صوت گفتار فارسی",type="filepath")
22
  text_output = gr.TextArea(label="هیجان موجود در صوت گفتار",text_align="right",rtl=True,type="text")
@@ -30,7 +25,7 @@ def SER(audio):
30
  speech_array, _sampling_rate = torchaudio.load(temp_audio_file.name)
31
  resampler = torchaudio.transforms.Resample(_sampling_rate)
32
  speech = resampler(speech_array).squeeze().numpy()
33
- inputs = model(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
34
  inputs = {key: inputs[key].to(device) for key in inputs}
35
 
36
  with torch.no_grad():
 
6
  import gradio as gr
7
  from transformers import Wav2Vec2FeatureExtractor,AutoConfig
8
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
+ from models import Wav2Vec2ForSpeechClassification, HubertForSpeechClassification
 
 
 
 
 
 
 
10
 
11
  config = AutoConfig.from_pretrained("SeyedAli/Persian-Speech-Emotion-HuBert-V1")
12
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("SeyedAli/Persian-Speech-Emotion-HuBert-V1")
13
+ model = HubertForSpeechClassification.from_pretrained("SeyedAli/Persian-Speech-Emotion-HuBert-V1")
14
+ sampling_rate = feature_extractor.sampling_rate
15
 
16
  audio_input = gr.Audio(label="صوت گفتار فارسی",type="filepath")
17
  text_output = gr.TextArea(label="هیجان موجود در صوت گفتار",text_align="right",rtl=True,type="text")
 
25
  speech_array, _sampling_rate = torchaudio.load(temp_audio_file.name)
26
  resampler = torchaudio.transforms.Resample(_sampling_rate)
27
  speech = resampler(speech_array).squeeze().numpy()
28
+ inputs = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
29
  inputs = {key: inputs[key].to(device) for key in inputs}
30
 
31
  with torch.no_grad():