Spaces:
Runtime error
Runtime error
import os
import shutil
import uuid
from enum import Enum

# Set Hugging Face cache directory (essential for Hugging Face Spaces).
# This has to happen BEFORE `transformers` is imported: the Hugging Face
# libraries read HF_HOME when they are first imported, so assigning it
# afterwards leaves the default cache location in effect.
os.environ["HF_HOME"] = "/app/.cache/huggingface"

from fastapi import FastAPI, File, UploadFile, Form  # noqa: E402
from fastapi.responses import JSONResponse  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from transformers import pipeline, MarianMTModel, MarianTokenizer  # noqa: E402

app = FastAPI()

# CORS for frontend: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Use smaller models to avoid timeout. Both pipelines are built once at import
# time and shared by every request.
asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
generator_pipeline = pipeline("text-generation", model="sshleifer/tiny-gpt2")
# Supported target languages. Declaring them as a str-based Enum makes them
# show up as a dropdown in Swagger UI; each member's value is the language
# code used as a key into `model_map`.
class LanguageEnum(str, Enum):
    ta = "ta"  # Tamil
    fr = "fr"  # French
    es = "es"  # Spanish
    de = "de"  # German
    it = "it"  # Italian
    hi = "hi"  # Hindi
    ru = "ru"  # Russian
    zh = "zh"  # Chinese
    ar = "ar"  # Arabic
# Language model mapping: target language code -> translation checkpoint name.
# Every entry follows the same "Helsinki-NLP/opus-mt-en-<code>" naming scheme
# (the "ta" entry was changed from a gsarti checkpoint to the Helsinki one).
# NOTE(review): confirm that each of these checkpoints, "ta" in particular,
# actually exists on the Hub.
model_map = {
    code: f"Helsinki-NLP/opus-mt-en-{code}"
    for code in ("fr", "es", "de", "it", "hi", "ru", "zh", "ar", "ta")
}
# (tokenizer, model) pairs that have already been loaded, keyed by checkpoint
# name. Filled lazily so only the languages actually requested use memory.
_translation_models = {}


def _load_translation_model(model_name):
    """Return the ``(tokenizer, model)`` pair for *model_name*, loading it on first use."""
    if model_name not in _translation_models:
        _translation_models[model_name] = (
            MarianTokenizer.from_pretrained(model_name),
            MarianMTModel.from_pretrained(model_name),
        )
    return _translation_models[model_name]


def translate_text(text, target_lang):
    """Translate *text* into *target_lang* with the checkpoint listed in ``model_map``.

    Args:
        text: The string to translate.
        target_lang: Language code used to look up the checkpoint in ``model_map``.

    Returns:
        The decoded translation, or the message
        ``"No model for language: <target_lang>"`` when the code has no entry
        in ``model_map``.
    """
    model_name = model_map.get(target_lang)
    if model_name is None:
        return f"No model for language: {target_lang}"

    # Reuse the already-loaded tokenizer/model instead of re-reading the
    # checkpoint from disk on every request.
    tokenizer, model = _load_translation_model(model_name)

    encoded = tokenizer([text], return_tensors="pt", padding=True)
    translated = model.generate(**encoded)
    # A single-item batch was submitted, so the first decoded string is the result.
    return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered with `app`.
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the ASR pipeline.

    The upload is copied to a uniquely named temporary ``.wav`` file in the
    working directory, passed to ``asr_pipeline``, and the temporary file is
    removed afterwards.

    Returns:
        A ``JSONResponse`` with the key ``transcribed_text``.
    """
    temp_file = f"temp_{uuid.uuid4().hex}.wav"
    try:
        # Writing inside the `try` ensures a partially written file is also
        # cleaned up when the copy itself fails.
        with open(temp_file, "wb") as f:
            shutil.copyfileobj(audio.file, f)
        result = asr_pipeline(temp_file)
        return JSONResponse(content={"transcribed_text": result["text"]})
    finally:
        # The file may not exist if `open` failed; do not mask that error.
        if os.path.exists(temp_file):
            os.remove(temp_file)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered with `app`.
async def translate(text: str = Form(...), target_lang: LanguageEnum = Form(...)):
    """Translate the submitted form text into the selected target language.

    Returns:
        A ``JSONResponse`` with the key ``translated_text``.
    """
    payload = {"translated_text": translate_text(text, target_lang.value)}
    return JSONResponse(content=payload)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered with `app`.
async def process(audio: UploadFile = File(...), target_lang: LanguageEnum = Form(...)):
    """Transcribe an uploaded audio file and translate the transcription.

    The upload is copied to a uniquely named temporary ``.wav`` file in the
    working directory, transcribed with ``asr_pipeline``, translated with
    ``translate_text``, and the temporary file is removed afterwards.

    Returns:
        A ``JSONResponse`` with the keys ``transcribed_text`` and
        ``translated_text``.
    """
    temp_file = f"temp_{uuid.uuid4().hex}.wav"
    try:
        # Writing inside the `try` ensures a partially written file is also
        # cleaned up when the copy itself fails.
        with open(temp_file, "wb") as f:
            shutil.copyfileobj(audio.file, f)
        result = asr_pipeline(temp_file)
        transcribed_text = result["text"]
        translated_text = translate_text(transcribed_text, target_lang.value)
        return JSONResponse(content={
            "transcribed_text": transcribed_text,
            "translated_text": translated_text,
        })
    finally:
        # The file may not exist if `open` failed; do not mask that error.
        if os.path.exists(temp_file):
            os.remove(temp_file)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered with `app`.
def generate(prompt: str = "Daily conversation", target_lang: LanguageEnum = LanguageEnum.fr):
    """Generate a short English text from *prompt* and translate it.

    Returns:
        A dict with the original ``prompt``, the generated ``english`` text
        (whitespace-stripped) and its ``translated`` version.
    """
    # One sequence is requested, so only the first result is used.
    outputs = generator_pipeline(prompt, max_length=30, num_return_sequences=1)
    english_text = outputs[0]["generated_text"].strip()
    translation = translate_text(english_text, target_lang.value)
    return {
        "prompt": prompt,
        "english": english_text,
        "translated": translation,
    }
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered with `app`.
def root():
    """Return a small status payload showing that the backend is running."""
    status = {"message": "β Backend is live!"}
    return status