akshansh36 committed on
Commit 03d15d8
1 Parent(s): 2540ac0

Update infer.py

Files changed (1)
  1. infer.py +77 -160
infer.py CHANGED
@@ -1,156 +1,78 @@
 
  import torch
  import numpy as np
  import time
- import sounddevice as sd
- import torchaudio
- import json
- from infer_rvc_python import BaseLoader
  import datetime
- import pyaudio
- # Get the current date and time
  now = datetime.datetime.now()
- # Format the date and time as a string
  timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")

- converter = BaseLoader(only_cpu=True, hubert_path='./hubert_base.pt', rmvpe_path='./rmvpe.pt')
- random_tag = "USER_" + str(timestamp)
  converter.apply_conf(
-     tag=random_tag,
-     file_model="./model.pth",
-     pitch_algo="rmvpe+",
-     pitch_lvl=0,
-     file_index="./model.index",
-     index_influence=0.80,
-     respiration_median_filtering=3,
-     envelope_ratio=0.25,
-     consonant_breath_protection=0.5,
-     resample_sr=0,
- )
- time.sleep(0.5)
  chunk_sec = 0.1
  sr = 16000
  chunk_len = int(sr * chunk_sec)
  L = 16
- b, a = converter.generate_from_cache(
-     audio_data="./AKSHAY KUMAR.wav",
-     tag=random_tag,
- )
- import soundfile as sf
-
- sf.write(
-     file="output_file.wav",
-     samplerate=a,
-     data=b
- )
- stop_recording = False
- def infer_stream(sr, max_duration):
-     global start_time
-     global first_output_latency
-     global audio_buffer
-
-     previous_chunk = torch.zeros(L * 2, dtype=torch.float32)
-
-     outputs = []
-     times = []
-     elapsed_time = 0
-
-     with torch.inference_mode():
-         while True:
-             if len(audio_buffer) < chunk_len:
-                 print(f'Buffer too small')
-                 time.sleep(0.1)
-                 continue  # Wait for enough data
-
-             # Get the current chunk
-             buffer_chunk = audio_buffer[:chunk_len]
-             audio_buffer = audio_buffer[chunk_len:]
-
-             # Add lookahead context
-             input_chunk = torch.cat([previous_chunk, buffer_chunk])
-             start = time.time()
-             # todo:
-             data = (input_chunk.numpy().astype(np.int16), sr)
-             print(data)
-             result_array, sample_rate = converter.generate_from_cache(
-                 audio_data=data,
-                 tag=random_tag,
-             )
-
-             if first_output_latency < 1:
-                 first_output_latency = time.time() - start_time
-                 print(f'first_output_latency {first_output_latency}')
-             # Convert the NumPy array (result_array) to a PyTorch tensor
-             output = torch.tensor(result_array, dtype=torch.float32)
-             outputs.append(output)
-             times.append(time.time() - start)
-
-             # Update the previous chunk with the last part of the current buffer_chunk
-             previous_chunk = buffer_chunk[-L * 2:]
-
-             # Check if the maximum duration has been reached
-             elapsed_time = time.time() - start_time
-             if elapsed_time > max_duration / 1.2 and len(audio_buffer) < chunk_len:
-                 break
-             else:
-                 print(f'Audio Buffer At Processing: {len(audio_buffer)} elapsed_time {elapsed_time}/{max_duration}')
-
-     # Concatenate outputs and calculate metrics
-     if outputs:
-         outputs = torch.cat(outputs, dim=2)
-         avg_time = np.mean(times)
-         total_time_processing = np.sum(times)
-         rtf = (chunk_len / sr) / avg_time
-         e2e_latency = ((2 * L + chunk_len) / sr + avg_time) * 1000
-         outputs = outputs.squeeze(0)
-     else:
-         rtf = e2e_latency = None

-     return outputs, rtf, e2e_latency, total_time_processing

- def save_audio(audio, audio_path, sample_rate):
-     torchaudio.save(audio_path, audio, sample_rate)
- max_duration = 2  # Maximum duration to process in seconds
- silence_threshold = 0.01  # Threshold to detect silence
- max_silence_duration = 0.5  # Maximum duration of silence to keep in seconds
-
- # Variable to track accumulated silence duration
- accumulated_silence_duration = 0.0
- # Callback function to process audio from mic
- def callback(indata, frames, time_info, status):
-     global audio_buffer, accumulated_silence_duration
-     global stop_recording
-     global stop_pro
-     if stop_recording:
-         if stop_pro < 10:
-             stop_pro += 1
-             print(f'Audio Buffer Stopped Recording: {len(audio_buffer)}')
-         return
-     audio_data = indata[:, 0]  # Use first channel if stereo
-     audio_data = torch.tensor(audio_data, dtype=torch.float32)
-
-     # Convert audio data to numpy for silence detection
-     audio_np = audio_data.numpy()
-
-     # Detect silence (audio below the threshold)
-     silence_indices = np.where(np.abs(audio_np) < silence_threshold)[0]
-     # Calculate the duration of the current chunk in seconds
-     chunk_duration = len(audio_np) / sample_rate
-
-     if len(silence_indices) == len(audio_np):
-         # All data is silent
-         accumulated_silence_duration += chunk_duration
-
-         if accumulated_silence_duration <= max_silence_duration:
-             audio_buffer = torch.cat((audio_buffer, audio_data))
      else:
-         # Non-silence detected, reset accumulated silence duration
-         accumulated_silence_duration = 0.0
-         audio_buffer = torch.cat((audio_buffer, audio_data))
-
-     if time.time() - start_time > max_duration:
-         stop_recording = True
-         print(f'Audio Buffer At Insert: {len(audio_buffer)}')
  def list_audio_devices():
      audio = pyaudio.PyAudio()
      device_count = audio.get_device_count()
 
@@ -158,33 +80,28 @@ def list_audio_devices():
      for i in range(device_count):
          device_info = audio.get_device_info_by_index(i)
          print(f"Index: {i}, Name: {device_info['name']}, Input Channels: {device_info['maxInputChannels']}, Output Channels: {device_info['maxOutputChannels']}")
- # Main script
- if __name__ == "__main__":
-     list_audio_devices()
-     chunk_size = 2024  # Size of each audio chunk
-     sample_rate = sr  # Sample rate from the model config
-     stop_pro = 0

-     # Initialize global audio buffer
-     audio_buffer = torch.zeros(0, dtype=torch.float32)

-     # Set up the microphone stream
-     input_device_index = 2  # Replace with your actual input device index

-     print("Recording...")
      start_time = time.time()
      first_output_latency = 0
-     final_output_latency = 0
-     try:
-         with sd.InputStream(samplerate=sample_rate, channels=1, callback=callback, blocksize=chunk_size, device=input_device_index):
-             output_waveform, rtf, e2e_latency, total_processing_time = infer_stream(sr=sample_rate, max_duration=max_duration)
-         if output_waveform is not None:
-             # Save output to file
-             final_output_latency = (time.time() - start_time) - (len(output_waveform[0]) / sample_rate)
-             save_audio(output_waveform, f'output_audio_stream_buff-{now}.wav', sample_rate)
-             print(f"Processed audio saved to output_audio_stream_buff.wav")
-             print(f'first_output_latency: {first_output_latency} || final_output_latency {final_output_latency} || total_processing_time {total_processing_time}')
-             if rtf is not None and e2e_latency is not None:
-                 print(f"RTF: {rtf}, E2E Latency: {e2e_latency} ms")
-     except KeyboardInterrupt:
-         print("Recording stopped.")
 
+ import gradio as gr
  import torch
  import numpy as np
  import time
+ import soundfile as sf
  import datetime
+ from infer_rvc_python import BaseLoader
+
+ # Initialize converter and other global variables
+ converter = BaseLoader(only_cpu=True, hubert_path='./hubert_base.pt', rmvpe_path='./rmvpe.pt')
  now = datetime.datetime.now()
  timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
+ random_tag = "USER_" + str(timestamp)

  converter.apply_conf(
+     tag=random_tag,
+     file_model="./model.pth",
+     pitch_algo="rmvpe+",
+     pitch_lvl=0,
+     file_index="./model.index",
+     index_influence=0.80,
+     respiration_median_filtering=3,
+     envelope_ratio=0.25,
+     consonant_breath_protection=0.5,
+     resample_sr=0,
+ )
+
+ # Constants and initializations
  chunk_sec = 0.1
  sr = 16000
  chunk_len = int(sr * chunk_sec)
  L = 16

+ # Define the streaming function for Gradio
+ def process_audio_stream(audio, instream):
+     global audio_buffer, start_time, first_output_latency, stop_recording
+
+     if audio is None:
+         return gr.update(), instream
+
+     if instream is None:
+         instream = torch.zeros(0, dtype=torch.float32)
+     # 'audio' arrives from Gradio as a (sample_rate, numpy array) tuple; convert the samples to a torch tensor
+     audio_data = torch.tensor(audio[1], dtype=torch.float32)
+
+     # Append new data to the audio buffer
+     audio_buffer = torch.cat((audio_buffer, audio_data))
+
+     if len(audio_buffer) >= chunk_len:
+         # Get the current chunk
+         buffer_chunk = audio_buffer[:chunk_len]
+         audio_buffer = audio_buffer[chunk_len:]
+
+         # Prepend lookahead context taken from the tail of the stream converted so far
+         input_chunk = torch.cat([instream[-L * 2:], buffer_chunk])
+         data = (input_chunk.numpy().astype(np.int16), sr)
+
+         result_array, out_sr = converter.generate_from_cache(audio_data=data, tag=random_tag)
+         output = torch.tensor(result_array, dtype=torch.float32)
+
+         # Append the processed output to instream for continuous processing
+         instream = torch.cat((instream, output))
+
+         # gr.Audio takes a (sample_rate, ndarray) tuple; cast back to int16 and keep the tensor in State
+         return (out_sr, instream.numpy().astype(np.int16)), instream
      else:
+         return gr.update(), instream
+
+ # Function to save audio to file
+ def save_audio(audio, audio_path, sample_rate):
+     import torchaudio  # local import; torchaudio is only needed when saving
+     # torchaudio.save expects a 2D (channels, time) tensor
+     torchaudio.save(audio_path, torch.tensor(audio, dtype=torch.float32).unsqueeze(0), sample_rate)
+
+ # Function to list audio devices (for debugging or selecting specific devices)
  def list_audio_devices():
+     import pyaudio
      audio = pyaudio.PyAudio()
      device_count = audio.get_device_count()

      for i in range(device_count):
          device_info = audio.get_device_info_by_index(i)
          print(f"Index: {i}, Name: {device_info['name']}, Input Channels: {device_info['maxInputChannels']}, Output Channels: {device_info['maxOutputChannels']}")

+ # Define Gradio interface
+ with gr.Blocks() as demo:
+     inp = gr.Audio(sources="microphone", streaming=True)
+     out = gr.Audio(streaming=True)
+     stream = gr.State()
+
+     inp.stream(process_audio_stream, [inp, stream], [out, stream])
+
+     # Button to clear/reset the stream; one return value per output component
+     clear = gr.Button("Clear")
+     clear.click(lambda: [None, None, torch.zeros(0, dtype=torch.float32)], None, [inp, out, stream])

+ if __name__ == "__main__":
+     # Initialize global audio buffer
+     audio_buffer = torch.zeros(0, dtype=torch.float32)
      start_time = time.time()
      first_output_latency = 0
+     stop_recording = False
+
+     # Optionally list audio devices (can be commented out if not needed)
+     # list_audio_devices()
+
+     # Launch Gradio interface
+     demo.launch()
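
Review note: with chunk_sec = 0.1 and sr = 16000, the streaming path consumes int(16000 * 0.1) = 1600-sample chunks, each preceded by a 2 * L = 32-sample lookahead tail. The sketch below, which is not part of this commit, shows how process_audio_stream could be exercised from within the infer.py module without a browser or microphone by replaying a file in Gradio-sized blocks. The helper name replay_file, the file test.wav, and the 160 ms block size are illustrative assumptions; the input file is assumed to be 16 kHz mono.

import soundfile as sf  # already a dependency of infer.py

def replay_file(path="test.wav", block_ms=160):
    """Hypothetical test helper: feed a WAV file through process_audio_stream in blocks."""
    global audio_buffer
    audio_buffer = torch.zeros(0, dtype=torch.float32)  # reset the module-level buffer
    samples, file_sr = sf.read(path, dtype="int16")
    if samples.ndim > 1:
        samples = samples[:, 0]  # keep the first channel, mirroring the old mic callback
    block = int(file_sr * block_ms / 1000)
    state = None
    for start in range(0, len(samples), block):
        payload = (file_sr, samples[start:start + block])  # mimic Gradio's (sr, ndarray) event
        out, state = process_audio_stream(payload, state)
    return state  # the full converted stream as a float32 tensor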