saronium commited on
Commit
6e4bc1e
·
verified ·
1 Parent(s): c0f1a73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -4
app.py CHANGED
@@ -5,19 +5,38 @@ from torchvision import models
5
  from scipy.ndimage import zoom
6
  import gradio as gr
7
  import pickle
 
8
 
9
 
10
  # Assuming you already have the 'ann_model' trained and 'pca' instance from the previous code
11
  language_mapping = {'malayalam': 0, 'english': 1, 'tamil': 2,'hindi':3,'kannada':4,'telugu':5}
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Load the trained model
15
- with open('ann_model.pth', 'rb') as f:
16
- ann_model = torch.load(f)
17
 
18
  # Load the PCA instance
19
- with open('pca.pkl', 'rb') as f:
20
- pca = pickle.load(f)
21
  vgg16 = models.vgg16(pretrained=True).features
22
  # Function to load and preprocess a single audio file
23
  def preprocess_single_audio_vgg16(audio_file, vgg16_model, pca_instance):
 
5
  from scipy.ndimage import zoom
6
  import gradio as gr
7
  import pickle
8
+ from joblib import load
9
 
10
 
11
  # Assuming you already have the 'ann_model' trained and 'pca' instance from the previous code
12
  language_mapping = {'malayalam': 0, 'english': 1, 'tamil': 2,'hindi':3,'kannada':4,'telugu':5}
13
 
14
+ class ANNModel(nn.Module):
15
+ def __init__(self):
16
+ super(ANNModel, self).__init__()
17
+ self.fc1 = nn.Linear(300, 128)
18
+ self.relu1 = nn.ReLU()
19
+ self.fc2 = nn.Linear(128, 64)
20
+ self.relu2 = nn.ReLU()
21
+ self.fc3 = nn.Linear(64, 6)
22
+
23
+ def forward(self, x):
24
+ x = self.fc1(x)
25
+ x = self.relu1(x)
26
+ x = self.fc2(x)
27
+ x = self.relu2(x)
28
+ x = self.fc3(x)
29
+ return x
30
+
31
+ # Create an instance of your model
32
+ ann_model = ANNModel()
33
 
34
  # Load the trained model
35
+ ann_model.load_state_dict(torch.load('ann_model.pth'))
 
36
 
37
  # Load the PCA instance
38
+ pca = load('pca.pkl')
39
+
40
  vgg16 = models.vgg16(pretrained=True).features
41
  # Function to load and preprocess a single audio file
42
  def preprocess_single_audio_vgg16(audio_file, vgg16_model, pca_instance):