import os

# Set the Hugging Face cache directory before importing transformers so it takes effect
# (essential for Hugging Face Spaces, where the default cache location may not be writable)
os.environ["HF_HOME"] = "/app/.cache/huggingface"

import shutil
import uuid
from enum import Enum

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, MarianMTModel, MarianTokenizer

app = FastAPI()

# CORS for frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ✅ Use smaller models to avoid startup timeouts
asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
generator_pipeline = pipeline("text-generation", model="sshleifer/tiny-gpt2")
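# Both pipelines above are created once at import time, so the model weights are
# downloaded into the HF_HOME cache when the Space boots rather than on each request.
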
# Supported languages (dropdown in Swagger UI)
class LanguageEnum(str, Enum):
    ta = "ta"
    fr = "fr"
    es = "es"
    de = "de"
    it = "it"
    hi = "hi"
    ru = "ru"
    zh = "zh"
    ar = "ar"

# Language model mapping
model_map = {
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "es": "Helsinki-NLP/opus-mt-en-es",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "it": "Helsinki-NLP/opus-mt-en-it",
    "hi": "Helsinki-NLP/opus-mt-en-hi",
    "ru": "Helsinki-NLP/opus-mt-en-ru",
    "zh": "Helsinki-NLP/opus-mt-en-zh",
    "ar": "Helsinki-NLP/opus-mt-en-ar",
    "ta": "Helsinki-NLP/opus-mt-en-ta",  # Changed from gsarti to Helsinki version
}

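# NOTE: translate_text() below loads the MarianMT tokenizer/model for the target
# language on every call. That keeps startup light but adds per-request latency;
# caching the loaded models in a dict would be one possible optimisation.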
def translate_text(text, target_lang):
    if target_lang not in model_map:
        return f"No model for language: {target_lang}"
    model_name = model_map[target_lang]
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    encoded = tokenizer([text], return_tensors="pt", padding=True)
    translated = model.generate(**encoded)
    return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]

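# /transcribe: save the uploaded audio to a unique temp file, run Whisper on it,
# and always delete the temp file afterwards (see the finally block).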
@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
temp_file = f"temp_{uuid.uuid4().hex}.wav"
with open(temp_file, "wb") as f:
shutil.copyfileobj(audio.file, f)
try:
result = asr_pipeline(temp_file)
return JSONResponse(content={"transcribed_text": result["text"]})
finally:
os.remove(temp_file)
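# /translate: translate already-transcribed (or any) English text into the
# selected target language.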
@app.post("/translate")
async def translate(text: str = Form(...), target_lang: LanguageEnum = Form(...)):
translated = translate_text(text, target_lang.value)
return JSONResponse(content={"translated_text": translated})
@app.post("/process")
async def process(audio: UploadFile = File(...), target_lang: LanguageEnum = Form(...)):
temp_file = f"temp_{uuid.uuid4().hex}.wav"
with open(temp_file, "wb") as f:
shutil.copyfileobj(audio.file, f)
try:
result = asr_pipeline(temp_file)
transcribed_text = result["text"]
translated_text = translate_text(transcribed_text, target_lang.value)
return JSONResponse(content={
"transcribed_text": transcribed_text,
"translated_text": translated_text
})
finally:
os.remove(temp_file)
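# /generate: produce a short English phrase and translate it. sshleifer/tiny-gpt2
# is a tiny test checkpoint, so the generated English is essentially placeholder
# text; fine for a demo, not for real phrase suggestions.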
@app.get("/generate")
def generate(prompt: str = "Daily conversation", target_lang: LanguageEnum = LanguageEnum.fr):
english = generator_pipeline(prompt, max_length=30, num_return_sequences=1)[0]["generated_text"].strip()
translated = translate_text(english, target_lang.value)
return {
"prompt": prompt,
"english": english,
"translated": translated
}
@app.get("/")
def root():
return {"message": "βœ… Backend is live!"}