sachin commited on
Commit
e5a6062
·
1 Parent(s): e0b5384
Files changed (1) hide show
  1. src/server/main.py +63 -0
src/server/main.py CHANGED
@@ -733,6 +733,69 @@ async def chat_v2(
733
  logger.error(f"Error processing request: {str(e)}")
734
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
736
  if __name__ == "__main__":
737
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
738
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
 
733
  logger.error(f"Error processing request: {str(e)}")
734
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
735
 
736
+ class TranscriptionResponse(BaseModel):
737
+ text: str
738
+
739
+
740
+ class ASRModelManager:
741
+ def __init__(self, device_type="cuda"):
742
+ self.device_type = device_type
743
+ self.model_language = {
744
+ "kannada": "kn", "hindi": "hi", "malayalam": "ml", "assamese": "as", "bengali": "bn",
745
+ "bodo": "brx", "dogri": "doi", "gujarati": "gu", "kashmiri": "ks", "konkani": "kok",
746
+ "maithili": "mai", "manipuri": "mni", "marathi": "mr", "nepali": "ne", "odia": "or",
747
+ "punjabi": "pa", "sanskrit": "sa", "santali": "sat", "sindhi": "sd", "tamil": "ta",
748
+ "telugu": "te", "urdu": "ur"
749
+ }
750
+
751
+
752
+ from fastapi import FastAPI, UploadFile
753
+ import torch
754
+ import torchaudio
755
+ from transformers import AutoModel
756
+ import argparse
757
+ import uvicorn
758
+ from pydantic import BaseModel
759
+ from pydub import AudioSegment
760
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Query
761
+ from fastapi.responses import RedirectResponse, JSONResponse
762
+ from typing import List
763
+
764
+ # Load the model
765
+ model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
766
+
767
+ asr_manager = ASRModelManager() # Load Kannada, Hindi, Tamil, Telugu, Malayalam
768
+
769
+
770
+ #asr_manager = ASRModelManager(device_type="")
771
+
772
+ @app.post("/transcribe/", response_model=TranscriptionResponse)
773
+ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
774
+ # Load the uploaded audio file
775
+ wav, sr = torchaudio.load(file.file)
776
+ wav = torch.mean(wav, dim=0, keepdim=True)
777
+
778
+ # Resample if necessary
779
+ target_sample_rate = 16000 # Expected sample rate
780
+ if sr != target_sample_rate:
781
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
782
+ wav = resampler(wav)
783
+
784
+ # Perform ASR with CTC decoding
785
+ #transcription_ctc = model(wav, "kn", "ctc")
786
+
787
+ # Perform ASR with RNNT decoding
788
+ transcription_rnnt = model(wav, "kn", "rnnt")
789
+
790
+ return JSONResponse(content={"text": transcription_rnnt})
791
+
792
+
793
+
794
+ class BatchTranscriptionResponse(BaseModel):
795
+ transcriptions: List[str]
796
+
797
+
798
+
799
  if __name__ == "__main__":
800
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
801
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")