import os
import shutil
import uuid
from enum import Enum

# Set the Hugging Face cache directory (essential for Hugging Face Spaces).
# This must be set before importing transformers, which reads HF_HOME at import time.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, MarianMTModel, MarianTokenizer

app = FastAPI()

# CORS for the frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ✅ Use small models to avoid startup timeouts
asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
generator_pipeline = pipeline("text-generation", model="sshleifer/tiny-gpt2")


# Supported target languages (rendered as a dropdown in the Swagger UI)
class LanguageEnum(str, Enum):
    ta = "ta"
    fr = "fr"
    es = "es"
    de = "de"
    it = "it"
    hi = "hi"
    ru = "ru"
    zh = "zh"
    ar = "ar"


# Target language -> MarianMT model mapping
model_map = {
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "es": "Helsinki-NLP/opus-mt-en-es",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "it": "Helsinki-NLP/opus-mt-en-it",
    "hi": "Helsinki-NLP/opus-mt-en-hi",
    "ru": "Helsinki-NLP/opus-mt-en-ru",
    "zh": "Helsinki-NLP/opus-mt-en-zh",
    "ar": "Helsinki-NLP/opus-mt-en-ar",
    "ta": "Helsinki-NLP/opus-mt-en-ta",  # Changed from gsarti to Helsinki version
}


def translate_text(text: str, target_lang: str) -> str:
    """Translate English text into the target language using the mapped MarianMT model."""
    if target_lang not in model_map:
        return f"No model for language: {target_lang}"
    model_name = model_map[target_lang]
    # Loaded per request for simplicity; cache these if latency becomes an issue.
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    encoded = tokenizer([text], return_tensors="pt", padding=True)
    translated = model.generate(**encoded)
    return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]


@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    # Save the upload to a uniquely named temp file so the ASR pipeline can read it from disk.
    temp_file = f"temp_{uuid.uuid4().hex}.wav"
    with open(temp_file, "wb") as f:
        shutil.copyfileobj(audio.file, f)
    try:
        result = asr_pipeline(temp_file)
        return JSONResponse(content={"transcribed_text": result["text"]})
    finally:
        os.remove(temp_file)


@app.post("/translate")
async def translate(text: str = Form(...), target_lang: LanguageEnum = Form(...)):
    translated = translate_text(text, target_lang.value)
    return JSONResponse(content={"translated_text": translated})


@app.post("/process")
async def process(audio: UploadFile = File(...), target_lang: LanguageEnum = Form(...)):
    # Transcribe the uploaded audio, then translate the transcription in one call.
    temp_file = f"temp_{uuid.uuid4().hex}.wav"
    with open(temp_file, "wb") as f:
        shutil.copyfileobj(audio.file, f)
    try:
        result = asr_pipeline(temp_file)
        transcribed_text = result["text"]
        translated_text = translate_text(transcribed_text, target_lang.value)
        return JSONResponse(content={
            "transcribed_text": transcribed_text,
            "translated_text": translated_text
        })
    finally:
        os.remove(temp_file)


@app.get("/generate")
def generate(prompt: str = "Daily conversation", target_lang: LanguageEnum = LanguageEnum.fr):
    # Generate a short English continuation of the prompt, then translate it.
    english = generator_pipeline(prompt, max_length=30, num_return_sequences=1)[0]["generated_text"].strip()
    translated = translate_text(english, target_lang.value)
    return {
        "prompt": prompt,
        "english": english,
        "translated": translated,
    }


@app.get("/")
def root():
    return {"message": "✅ Backend is live!"}
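

# --- Local run sketch (assumptions: this module is app.py and port 7860 is used,
# as is conventional on Hugging Face Spaces; on Spaces the Dockerfile/CMD normally
# starts uvicorn, so this block is only a convenience for local testing) ---
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example requests once the server is up (paths match the routes defined above):
#   curl -X POST -F "audio=@sample.wav" http://localhost:7860/transcribe
#   curl -X POST -F "text=Hello" -F "target_lang=fr" http://localhost:7860/translate
#   curl "http://localhost:7860/generate?prompt=Daily%20conversation&target_lang=es"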