amanmibra commited on
Commit
81d87a8
·
1 Parent(s): 5fb3738

Add url support to API

Browse files
Files changed (2) hide show
  1. server/main.py +9 -16
  2. server/preprocess.py +21 -1
server/main.py CHANGED
@@ -4,14 +4,12 @@ sys.path.append('..')
4
 
5
  import os
6
  from fastapi import FastAPI
7
- from pydantic import BaseModel
8
- import wget
9
 
10
  # torch
11
  import torch
12
 
13
  # utils
14
- from preprocess import process_from_filename, process_raw_wav
15
  from cnn import CNNetwork
16
 
17
  # load model
@@ -21,11 +19,6 @@ model.load_state_dict(state_dict)
21
 
22
  print(f"Model loaded! \n {model}")
23
 
24
- # /predict input
25
- # class Data(BaseModel):
26
- # wav:
27
-
28
-
29
  app = FastAPI()
30
 
31
  @app.get("/")
@@ -34,12 +27,13 @@ async def root():
34
 
35
  @app.get("/urlpredict")
36
  def url_predict(url: str):
37
- filename = wget.download(url)
38
- wav = process_from_filename(filename)
39
- print(f"\ntest {wav.shape}\n")
40
 
41
  model_prediction = model_predict(wav)
42
- return model_prediction["predicition_index"]
 
 
 
43
 
44
  @app.put("/predict")
45
  def predict(wav):
@@ -49,16 +43,15 @@ def predict(wav):
49
  model_prediction = model_predict(wav)
50
 
51
  return {
52
- "message": "Voiced Identified!",
53
  "data": model_prediction,
54
  }
55
 
56
  def model_predict(wav):
57
  model_input = wav.unsqueeze(0)
58
  output = model(model_input)
59
- prediction = torch.argmax(output, 1).item()
60
 
61
  return {
62
- "output": output,
63
- "prediction_index": prediction,
64
  }
 
4
 
5
  import os
6
  from fastapi import FastAPI
 
 
7
 
8
  # torch
9
  import torch
10
 
11
  # utils
12
+ from preprocess import process_from_filename, process_from_url, process_raw_wav
13
  from cnn import CNNetwork
14
 
15
  # load model
 
19
 
20
  print(f"Model loaded! \n {model}")
21
 
 
 
 
 
 
22
  app = FastAPI()
23
 
24
  @app.get("/")
 
27
 
28
  @app.get("/urlpredict")
29
  def url_predict(url: str):
30
+ wav = process_from_url(url)
 
 
31
 
32
  model_prediction = model_predict(wav)
33
+ return {
34
+ "message": "Voice Identified!",
35
+ "data": model_prediction,
36
+ }
37
 
38
  @app.put("/predict")
39
  def predict(wav):
 
43
  model_prediction = model_predict(wav)
44
 
45
  return {
46
+ "message": "Voice Identified!",
47
  "data": model_prediction,
48
  }
49
 
50
  def model_predict(wav):
51
  model_input = wav.unsqueeze(0)
52
  output = model(model_input)
53
+ prediction_index = torch.argmax(output, 1).item()
54
 
55
  return {
56
+ "prediction_index": prediction_index,
 
57
  }
server/preprocess.py CHANGED
@@ -1,11 +1,31 @@
1
  """
2
  Util functions to process any incoming audio data to be processable by the model
3
  """
 
 
4
  import torch
5
  import torchaudio
 
 
6
 
7
  DEFAULT_SAMPLE_RATE=48000
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=5):
10
  wav, sample_rate = torchaudio.load(filename)
11
 
@@ -58,6 +78,6 @@ def _pad(wav, num_samples):
58
  if wav.shape[1] < num_samples:
59
  missing_samples = num_samples - wav.shape[1]
60
  pad = (0, missing_samples)
61
- wav = torch.nn.function.pad(wav, pad)
62
 
63
  return wav
 
1
  """
2
  Util functions to process any incoming audio data to be processable by the model
3
  """
4
+ import os
5
+ import librosa
6
  import torch
7
  import torchaudio
8
+ from scipy.io import wavfile
9
+ import wget
10
 
11
  DEFAULT_SAMPLE_RATE=48000
12
 
13
+ def process_from_url(url):
14
+ # download UI audio
15
+ filename = wget.download(url)
16
+ audio, sr = librosa.load(filename)
17
+ wavfile.write('temp.wav', DEFAULT_SAMPLE_RATE, audio)
18
+
19
+ # remove wget file
20
+ os.remove(filename)
21
+
22
+ # spec
23
+ spec = process_from_filename('temp.wav')
24
+
25
+ os.remove('temp.wav')
26
+ return spec
27
+
28
+
29
  def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=5):
30
  wav, sample_rate = torchaudio.load(filename)
31
 
 
78
  if wav.shape[1] < num_samples:
79
  missing_samples = num_samples - wav.shape[1]
80
  pad = (0, missing_samples)
81
+ wav = torch.nn.functional.pad(wav, pad)
82
 
83
  return wav