import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import streamlit as st
from pydub import AudioSegment
import os
import soundfile as sf
import uuid

# Set device and dtype
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32


@st.cache_resource  # Cache the loaded model across Streamlit reruns
def load_model():
    # whisper-large-v2 is a multilingual checkpoint that handles Hindi well;
    # the output language is fixed to Hindi at generation time below
    model_id = "openai/whisper-large-v2"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    model.to(device)

    # Use the processor (tokenizer + feature extractor) from the same model
    processor = AutoProcessor.from_pretrained(model_id)

    # Create the ASR pipeline with Hindi as the transcription language
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
        generate_kwargs={"language": "hi"},  # Transcribe in Hindi
    )
    return pipe, processor

# Load model and processor
pipe, processor = load_model()
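# Note: Whisper transcribes at most ~30 s of audio per forward pass. For longer
# recordings, the pipeline can chunk the input itself; a minimal sketch (the
# chunk_length_s and batch_size values here are illustrative assumptions, not
# tuned settings):
#
#     result = pipe("long_audio.wav", chunk_length_s=30, batch_size=8)
#     transcription = result["text"]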

# Streamlit UI
st.title("Hindi Audio to Text Transcription")

uploaded_file = st.file_uploader(
    "Upload a .wav audio file for transcription", type=["wav"]
)

if uploaded_file is not None:
    st.info("Processing uploaded file...")
    temp_filename = f"temp_audio_{uuid.uuid4()}.wav"
    with open(temp_filename, "wb") as f:
        f.write(uploaded_file.read())

    try:
        # Preprocess: mono, 16 kHz -- the sampling rate Whisper expects
        sound = AudioSegment.from_file(temp_filename)
        sound = sound.set_channels(1).set_frame_rate(16000)
        sound.export(temp_filename, format="wav")
        audio, sampling_rate = sf.read(temp_filename)

        # Extract log-mel features and match the model's device and dtype
        inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
        input_features = inputs.input_features.to(device, dtype=torch_dtype)

        # Perform transcription, decoding in Hindi
        with torch.no_grad():
            outputs = pipe.model.generate(
                input_features, language="hi", task="transcribe"
            )
        transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]

        # Display the transcription
        st.success("Transcription complete!")
        st.markdown(f"### Transcription:\n\n{transcription}")
    finally:
        os.remove(temp_filename)  # Clean up even if transcription fails
else:
    st.warning("Please upload a .wav file to start transcription.")
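
# To run locally: `streamlit run app.py`. Note that pydub may need ffmpeg
# installed on the system to decode some audio inputs.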