import gradio as gr
import torch

from model import Model
from config import Config

import warnings

config = Config()


def inference(audio_file1):
    # Standalone debug helper: it only logs the incoming audio file path and is
    # not wired into the Gradio interface below.
    print(f"[LOG] Audio file: {audio_file1}")

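# DFSeparationApp (defined below) loads a fine-tuned real-vs-fake audio
# classifier from a checkpoint and serves it through a Gradio interface; the
# Model and Config classes come from the local model.py / config.py modules.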
class DFSeparationApp:
    def __init__(self, model_path, device="cpu"):
        self.device = device
        # Load the checkpoint, then move the model onto the requested device.
        self.model = self.load_model(model_path)
        self.model.to(self.device)

    def load_model(self, model_path):
        # Load the checkpoint on the CPU first; __init__ moves the model to the
        # target device afterwards.
        checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
        fine_tuned_model = Model(
            args=config,
            device=self.device,
        )
        fine_tuned_model.load_state_dict(checkpoint["model"])
        # Put the model in evaluation mode so dropout/batch norm behave
        # deterministically at inference time.
        fine_tuned_model.eval()
        print("[LOG] Model loaded successfully.")
        return fine_tuned_model

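    # Assumption about the checkpoint format: load_model() reads the weights from
    # checkpoint["model"], i.e. the file is expected to have been saved along the
    # lines of torch.save({"model": model.state_dict()}, model_path).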
    def predict(self, audio_file):
        # gr.Audio passes a (sample_rate, samples) tuple; turn the raw samples
        # into a float tensor on the model's device. No resampling or
        # normalisation is applied here.
        audio_tensor = torch.tensor(audio_file[1], dtype=torch.float32).to(self.device)
        with torch.no_grad():
            output = self.model(audio_tensor)
            preds = output.argmax(dim=-1)
            probs = output.softmax(dim=-1)
        print(f"[LOG] Prediction: {preds.item()}")
        print(f"[LOG] Probability: {probs.max().item()}")
        return preds.item(), probs.max().item()

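    # Note on predict(): it assumes a single clip per call, so preds.item() is the
    # predicted class index and probs.max().item() is the probability the model
    # assigns to that predicted class.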
    def run(self):
        print("[LOG] Running the app...")
        audio_input1 = gr.Audio(label="Upload or record audio")
        prediction = gr.Label(label="Prediction:")
        prob = gr.Label(label="Probability:")
        gr.Interface(
            fn=self.predict,
            inputs=[audio_input1],
            outputs=[prediction, prob],
            title="DF Separation",
            description="This app classifies audio samples as Real or Fake.",
            # gr.Interface expects one example value per input component, so the
            # original "0"/"1" annotations are kept as comments (Fake = 1, Real = 0).
            examples=[
                ["samples/Fake/download (5).wav"],      # fake (1)
                ["samples/Fake/fake1_1.wav"],           # fake (1)
                ["samples/Real/Central Avenue 1.wav"],  # real (0)
                ["samples/Real/hindi.mp3"],             # real (0)
            ],
        ).launch(quiet=False, server_name="0.0.0.0")


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[LOG] Device: {device}")
    model_path = "models/for_trained_model.ckpt"
    app = DFSeparationApp(model_path, device=device)
    app.run()
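
# Usage sketch (assumption: this script is saved as app.py): run `python app.py`.
# With server_name="0.0.0.0" the server listens on all interfaces, on Gradio's
# default port (7860) since no server_port is passed to launch().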