Hugging Face Spaces — commit "Update main.py" (Space status: Runtime error).
Diff view of main.py (file CHANGED); old removed lines below, new file content follows.
@@ -1,28 +1,25 @@
|
|
1 |
from fastapi import FastAPI, File, UploadFile, Form
|
|
|
2 |
from fastapi.responses import JSONResponse
|
3 |
-
from
|
4 |
-
from transformers import pipeline, MarianMTModel, MarianTokenizer
|
5 |
-
import
|
6 |
-
import
|
7 |
-
import
|
8 |
-
from googletrans import Translator
|
9 |
|
10 |
app = FastAPI()
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
# π Map target language to translation model
|
25 |
-
model_map = {
|
26 |
"fr": "Helsinki-NLP/opus-mt-en-fr",
|
27 |
"es": "Helsinki-NLP/opus-mt-en-es",
|
28 |
"de": "Helsinki-NLP/opus-mt-en-de",
|
@@ -31,78 +28,108 @@ model_map = {
|
|
31 |
"ru": "Helsinki-NLP/opus-mt-en-ru",
|
32 |
"zh": "Helsinki-NLP/opus-mt-en-zh",
|
33 |
"ar": "Helsinki-NLP/opus-mt-en-ar",
|
34 |
-
"ta": "
|
35 |
}
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
model_name = model_map[target_lang]
|
50 |
-
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
51 |
-
model = MarianMTModel.from_pretrained(model_name)
|
52 |
-
encoded = tokenizer([text], return_tensors="pt", padding=True)
|
53 |
-
translated = model.generate(**encoded)
|
54 |
-
return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
|
55 |
-
|
56 |
-
# π§ Generate a random English sentence
|
57 |
-
def generate_random_sentence(prompt="Daily conversation", max_length=30):
|
58 |
-
generator = pipeline("text-generation", model="distilgpt2")
|
59 |
-
result = generator(prompt, max_length=max_length, num_return_sequences=1)
|
60 |
-
return result[0]["generated_text"].strip()
|
61 |
-
|
62 |
-
# π€ Transcription endpoint
|
63 |
-
@app.post("/transcribe")
|
64 |
-
async def transcribe(audio: UploadFile = File(...)):
|
65 |
-
temp_filename = f"temp_{uuid.uuid4().hex}.wav"
|
66 |
-
with open(temp_filename, "wb") as f:
|
67 |
-
shutil.copyfileobj(audio.file, f)
|
68 |
try:
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
-
# π Translation endpoint
|
76 |
@app.post("/translate")
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
-
# π Combined endpoint (speech-to-translation)
|
82 |
@app.post("/process")
|
83 |
-
async def
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
try:
|
88 |
-
|
89 |
-
|
90 |
-
transcribed_text = result["text"]
|
91 |
-
translated_text = translate_text(transcribed_text, target_lang.value)
|
92 |
-
return JSONResponse(content={
|
93 |
-
"transcribed_text": transcribed_text,
|
94 |
-
"translated_text": translated_text
|
95 |
-
})
|
96 |
-
finally:
|
97 |
-
os.remove(temp_filename)
|
98 |
|
99 |
-
#
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
"
|
107 |
-
"
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from fastapi import FastAPI, File, UploadFile, Form
|
2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
3 |
from fastapi.responses import JSONResponse
|
4 |
+
from pydantic import BaseModel
|
5 |
+
from transformers import pipeline, MarianMTModel, MarianTokenizer, WhisperProcessor, WhisperForConditionalGeneration
|
6 |
+
import torch
|
7 |
+
import tempfile
|
8 |
+
import soundfile as sf
|
|
|
9 |
|
10 |
app = FastAPI()
|
11 |
|
12 |
+
# Allow frontend to call backend
|
13 |
+
app.add_middleware(
|
14 |
+
CORSMiddleware,
|
15 |
+
allow_origins=["*"],
|
16 |
+
allow_credentials=True,
|
17 |
+
allow_methods=["*"],
|
18 |
+
allow_headers=["*"],
|
19 |
+
)
|
20 |
+
|
21 |
+
# Supported languages
|
22 |
+
translation_models = {
|
|
|
|
|
|
|
23 |
"fr": "Helsinki-NLP/opus-mt-en-fr",
|
24 |
"es": "Helsinki-NLP/opus-mt-en-es",
|
25 |
"de": "Helsinki-NLP/opus-mt-en-de",
|
|
|
28 |
"ru": "Helsinki-NLP/opus-mt-en-ru",
|
29 |
"zh": "Helsinki-NLP/opus-mt-en-zh",
|
30 |
"ar": "Helsinki-NLP/opus-mt-en-ar",
|
31 |
+
"ta": "Helsinki-NLP/opus-mt-en-ta"
|
32 |
}
|
33 |
|
34 |
+
# Load models once
|
35 |
+
generator = pipeline("text-generation", model="gpt2")
|
36 |
+
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
|
37 |
+
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
|
38 |
+
|
39 |
+
@app.get("/")
|
40 |
+
def root():
|
41 |
+
return {"message": "Backend is live β
"}
|
42 |
+
|
43 |
+
@app.get("/generate")
|
44 |
+
def generate_and_translate(prompt: str, target_lang: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
try:
|
46 |
+
if target_lang not in translation_models:
|
47 |
+
return {"error": "Unsupported language."}
|
48 |
+
|
49 |
+
# 1. Generate English sentence
|
50 |
+
result = generator(prompt, max_length=30, num_return_sequences=1)[0]["generated_text"]
|
51 |
+
english_sentence = result.strip()
|
52 |
+
|
53 |
+
# 2. Translate
|
54 |
+
model_name = translation_models[target_lang]
|
55 |
+
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
56 |
+
model = MarianMTModel.from_pretrained(model_name)
|
57 |
+
tokens = tokenizer(english_sentence, return_tensors="pt", padding=True)
|
58 |
+
translated_ids = model.generate(**tokens)
|
59 |
+
translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
60 |
+
|
61 |
+
return {"english": english_sentence, "translated": translated_text}
|
62 |
+
except Exception as e:
|
63 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
64 |
+
|
65 |
+
class TranslateRequest(BaseModel):
|
66 |
+
text: str
|
67 |
+
target_lang: str
|
68 |
|
|
|
69 |
@app.post("/translate")
|
70 |
+
def translate_text(data: TranslateRequest):
|
71 |
+
try:
|
72 |
+
if data.target_lang not in translation_models:
|
73 |
+
return {"error": "Unsupported language."}
|
74 |
+
|
75 |
+
model_name = translation_models[data.target_lang]
|
76 |
+
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
77 |
+
model = MarianMTModel.from_pretrained(model_name)
|
78 |
+
tokens = tokenizer(data.text, return_tensors="pt", padding=True)
|
79 |
+
translated_ids = model.generate(**tokens)
|
80 |
+
translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
81 |
+
|
82 |
+
return {"translated_text": translated_text}
|
83 |
+
except Exception as e:
|
84 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
85 |
+
|
86 |
+
@app.post("/transcribe")
|
87 |
+
async def transcribe_audio(audio: UploadFile = File(...)):
|
88 |
+
try:
|
89 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
90 |
+
temp_file.write(await audio.read())
|
91 |
+
temp_file.close()
|
92 |
+
|
93 |
+
audio_data, _ = sf.read(temp_file.name)
|
94 |
+
inputs = whisper_processor(audio_data, sampling_rate=16000, return_tensors="pt")
|
95 |
+
predicted_ids = whisper_model.generate(inputs["input_features"])
|
96 |
+
transcription = whisper_processor.decode(predicted_ids[0], skip_special_tokens=True)
|
97 |
+
|
98 |
+
return {"transcribed_text": transcription}
|
99 |
+
except Exception as e:
|
100 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
101 |
|
|
|
102 |
@app.post("/process")
|
103 |
+
async def transcribe_and_translate_audio(
|
104 |
+
audio: UploadFile = File(...),
|
105 |
+
target_lang: str = Form(...)
|
106 |
+
):
|
107 |
try:
|
108 |
+
if target_lang not in translation_models:
|
109 |
+
return {"error": "Unsupported language."}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
+
# Save uploaded file
|
112 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
113 |
+
temp_file.write(await audio.read())
|
114 |
+
temp_file.close()
|
115 |
+
|
116 |
+
# Transcribe
|
117 |
+
audio_data, _ = sf.read(temp_file.name)
|
118 |
+
inputs = whisper_processor(audio_data, sampling_rate=16000, return_tensors="pt")
|
119 |
+
predicted_ids = whisper_model.generate(inputs["input_features"])
|
120 |
+
transcription = whisper_processor.decode(predicted_ids[0], skip_special_tokens=True)
|
121 |
+
|
122 |
+
# Translate
|
123 |
+
model_name = translation_models[target_lang]
|
124 |
+
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
125 |
+
model = MarianMTModel.from_pretrained(model_name)
|
126 |
+
tokens = tokenizer(transcription, return_tensors="pt", padding=True)
|
127 |
+
translated_ids = model.generate(**tokens)
|
128 |
+
translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
129 |
+
|
130 |
+
return {
|
131 |
+
"transcribed_text": transcription,
|
132 |
+
"translated_text": translated_text
|
133 |
+
}
|
134 |
+
except Exception as e:
|
135 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|