# FastAPI backend: text generation (distilgpt2), English->X translation
# (Helsinki-NLP MarianMT), and speech transcription (openai/whisper-base).
import os
import tempfile

import soundfile as sf
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    pipeline,
)
app = FastAPI()

# Allow the frontend to call this backend from any origin.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests — pin concrete origins
# before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Supported languages | |
translation_models = { | |
"fr": "Helsinki-NLP/opus-mt-en-fr", | |
"es": "Helsinki-NLP/opus-mt-en-es", | |
"de": "Helsinki-NLP/opus-mt-en-de", | |
"it": "Helsinki-NLP/opus-mt-en-it", | |
"hi": "Helsinki-NLP/opus-mt-en-hi", | |
"ru": "Helsinki-NLP/opus-mt-en-ru", | |
"zh": "Helsinki-NLP/opus-mt-en-zh", | |
"ar": "Helsinki-NLP/opus-mt-en-ar", | |
"ta": "Helsinki-NLP/opus-mt-en-ta" | |
} | |
# Load models once | |
generator = pipeline("text-generation", model="distilgpt2", framework="tf", from_tf=True) | |
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base") | |
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base") | |
# BUGFIX: the route decorator was missing, so this health check was never
# registered with FastAPI. TODO(review): confirm "/" is the intended path.
@app.get("/")
def root():
    """Health-check endpoint: confirms the backend process is up."""
    return {"message": "Backend is live β "}
# BUGFIX: the route decorator was missing, so this endpoint was never
# registered. TODO(review): confirm the intended path and HTTP method.
@app.get("/generate_and_translate")
def generate_and_translate(prompt: str, target_lang: str):
    """Generate a short English sentence from *prompt*, then translate it.

    Args:
        prompt: seed text fed to the distilgpt2 text-generation pipeline.
        target_lang: ISO 639-1 code; must be a key of ``translation_models``.

    Returns:
        ``{"english": ..., "translated": ...}`` on success,
        ``{"error": ...}`` for an unsupported language, or a 500
        ``JSONResponse`` if generation/translation raises.
    """
    try:
        if target_lang not in translation_models:
            return {"error": "Unsupported language."}
        # 1. Generate an English sentence (first of one returned sequence).
        generated = generator(prompt, max_length=30, num_return_sequences=1)
        english_sentence = generated[0]["generated_text"].strip()
        # 2. Translate with the MarianMT model for the requested language.
        #    NOTE(review): loaded per request — consider caching with
        #    functools.lru_cache if latency matters.
        model_name = translation_models[target_lang]
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        tokens = tokenizer(english_sentence, return_tensors="pt", padding=True)
        translated_ids = model.generate(**tokens)
        translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return {"english": english_sentence, "translated": translated_text}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
class TranslateRequest(BaseModel):
    """Request body for the text-translation endpoint."""

    # English source text to translate.
    text: str
    # ISO 639-1 target-language code; must be a key of ``translation_models``.
    target_lang: str
# BUGFIX: the route decorator was missing, so this endpoint was never
# registered. TODO(review): confirm the intended path.
@app.post("/translate")
def translate_text(data: TranslateRequest):
    """Translate ``data.text`` from English into ``data.target_lang``.

    Returns ``{"translated_text": ...}`` on success, ``{"error": ...}`` for an
    unsupported language, or a 500 ``JSONResponse`` on failure.
    """
    try:
        if data.target_lang not in translation_models:
            return {"error": "Unsupported language."}
        # MarianMT model/tokenizer for the requested target language.
        # NOTE(review): loaded per request — consider caching if latency matters.
        model_name = translation_models[data.target_lang]
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        tokens = tokenizer(data.text, return_tensors="pt", padding=True)
        translated_ids = model.generate(**tokens)
        translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return {"translated_text": translated_text}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
# BUGFIX: the route decorator was missing, so this endpoint was never
# registered. TODO(review): confirm the intended path.
@app.post("/transcribe")
async def transcribe_audio(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file with Whisper.

    NOTE(review): the processor is told sampling_rate=16000 but the upload is
    never resampled — audio recorded at other rates will transcribe poorly.
    Confirm the frontend records at 16 kHz or add resampling.
    """
    temp_path = None
    try:
        # Persist the upload so soundfile can read it from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(await audio.read())
            temp_path = temp_file.name
        audio_data, _ = sf.read(temp_path)
        inputs = whisper_processor(audio_data, sampling_rate=16000, return_tensors="pt")
        predicted_ids = whisper_model.generate(inputs["input_features"])
        transcription = whisper_processor.decode(predicted_ids[0], skip_special_tokens=True)
        return {"transcribed_text": transcription}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # BUGFIX: the delete=False temp file was leaked; remove it explicitly.
        if temp_path is not None:
            os.remove(temp_path)
# BUGFIX: the route decorator was missing, so this endpoint was never
# registered. TODO(review): confirm the intended path.
@app.post("/transcribe_and_translate")
async def transcribe_and_translate_audio(
    audio: UploadFile = File(...),
    target_lang: str = Form(...)
):
    """Transcribe an uploaded audio file, then translate the transcription.

    Args:
        audio: uploaded audio file (treated as WAV on disk).
        target_lang: ISO 639-1 code; must be a key of ``translation_models``.

    Returns:
        ``{"transcribed_text": ..., "translated_text": ...}`` on success,
        ``{"error": ...}`` for an unsupported language, or a 500
        ``JSONResponse`` on failure.

    NOTE(review): sampling_rate=16000 is asserted without resampling —
    confirm the frontend records at 16 kHz.
    """
    temp_path = None
    try:
        if target_lang not in translation_models:
            return {"error": "Unsupported language."}
        # Persist the upload so soundfile can read it from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(await audio.read())
            temp_path = temp_file.name
        # Transcribe with Whisper.
        audio_data, _ = sf.read(temp_path)
        inputs = whisper_processor(audio_data, sampling_rate=16000, return_tensors="pt")
        predicted_ids = whisper_model.generate(inputs["input_features"])
        transcription = whisper_processor.decode(predicted_ids[0], skip_special_tokens=True)
        # Translate the transcription with MarianMT.
        model_name = translation_models[target_lang]
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        tokens = tokenizer(transcription, return_tensors="pt", padding=True)
        translated_ids = model.generate(**tokens)
        translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return {
            "transcribed_text": transcription,
            "translated_text": translated_text
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # BUGFIX: the delete=False temp file was leaked; remove it explicitly.
        if temp_path is not None:
            os.remove(temp_path)
if __name__ == "__main__":
    # Local development entry point; Spaces/production typically launch uvicorn
    # externally, so the import stays function-scoped.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)