Aswinthmani committed on
Commit
8c83892
Β·
verified Β·
1 Parent(s): fbd661f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +113 -86
main.py CHANGED
@@ -1,28 +1,25 @@
1
  from fastapi import FastAPI, File, UploadFile, Form
 
2
  from fastapi.responses import JSONResponse
3
- from enum import Enum
4
- from transformers import pipeline, MarianMTModel, MarianTokenizer
5
- import shutil
6
- import os
7
- import uuid
8
- from googletrans import Translator
9
 
10
  app = FastAPI()
11
 
12
# 🌍 Supported target languages — rendered as a dropdown in Swagger UI.
class LanguageEnum(str, Enum):
    """Closed set of translation target language codes (ISO 639-1)."""

    ta = "ta"  # Tamil
    fr = "fr"  # French
    es = "es"  # Spanish
    de = "de"  # German
    it = "it"  # Italian
    hi = "hi"  # Hindi
    ru = "ru"  # Russian
    zh = "zh"  # Chinese
    ar = "ar"  # Arabic
23
-
24
- # 🌐 Map target language to translation model
25
- model_map = {
26
  "fr": "Helsinki-NLP/opus-mt-en-fr",
27
  "es": "Helsinki-NLP/opus-mt-en-es",
28
  "de": "Helsinki-NLP/opus-mt-en-de",
@@ -31,78 +28,108 @@ model_map = {
31
  "ru": "Helsinki-NLP/opus-mt-en-ru",
32
  "zh": "Helsinki-NLP/opus-mt-en-zh",
33
  "ar": "Helsinki-NLP/opus-mt-en-ar",
34
- "ta": "gsarti/opus-mt-en-ta"
35
  }
36
 
37
def translate_text(text, target_lang):
    """Translate English *text* into *target_lang*.

    Tamil is routed through googletrans; every other supported language uses
    the matching MarianMT checkpoint from ``model_map``. On failure or an
    unknown language code, an error *string* is returned instead of raising.
    """
    # Tamil goes through Google Translate rather than a local Marian model.
    if target_lang == "ta":
        try:
            return Translator().translate(text, dest="ta").text
        except Exception as e:
            return f"Google Translate failed: {str(e)}"

    if target_lang not in model_map:
        return f"No model for language: {target_lang}"

    # NOTE(review): the checkpoint is re-loaded on every call — acceptable
    # for low traffic, but a cache would help under load.
    checkpoint = model_map[target_lang]
    mt_tokenizer = MarianTokenizer.from_pretrained(checkpoint)
    mt_model = MarianMTModel.from_pretrained(checkpoint)
    batch = mt_tokenizer([text], return_tensors="pt", padding=True)
    output_ids = mt_model.generate(**batch)
    return mt_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
55
-
56
# 🧠 Produce a short English sentence for the translation demo.
def generate_random_sentence(prompt="Daily conversation", max_length=30):
    """Generate up to *max_length* tokens of English text seeded by *prompt*."""
    text_gen = pipeline("text-generation", model="distilgpt2")
    outputs = text_gen(prompt, max_length=max_length, num_return_sequences=1)
    return outputs[0]["generated_text"].strip()
61
-
62
# 🎀 Speech-to-text endpoint.
@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file with Whisper and return the text."""
    tmp_path = f"temp_{uuid.uuid4().hex}.wav"
    # Persist the upload to disk because the ASR pipeline takes a file path.
    with open(tmp_path, "wb") as out_file:
        shutil.copyfileobj(audio.file, out_file)
    try:
        recognizer = pipeline("automatic-speech-recognition", model="openai/whisper-medium")
        transcription = recognizer(tmp_path)
        return JSONResponse(content={"transcribed_text": transcription["text"]})
    finally:
        # Always remove the temp file, even when transcription fails.
        os.remove(tmp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
# 🌍 Text-translation endpoint.
@app.post("/translate")
async def translate(text: str = Form(...), target_lang: LanguageEnum = Form(...)):
    """Translate form-posted English text into the selected target language."""
    result = translate_text(text, target_lang.value)
    return JSONResponse(content={"translated_text": result})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- # πŸ” Combined endpoint (speech-to-translation)
82
  @app.post("/process")
83
- async def process(audio: UploadFile = File(...), target_lang: LanguageEnum = Form(...)):
84
- temp_filename = f"temp_{uuid.uuid4().hex}.wav"
85
- with open(temp_filename, "wb") as f:
86
- shutil.copyfileobj(audio.file, f)
87
  try:
88
- asr = pipeline("automatic-speech-recognition", model="openai/whisper-medium")
89
- result = asr(temp_filename)
90
- transcribed_text = result["text"]
91
- translated_text = translate_text(transcribed_text, target_lang.value)
92
- return JSONResponse(content={
93
- "transcribed_text": transcribed_text,
94
- "translated_text": translated_text
95
- })
96
- finally:
97
- os.remove(temp_filename)
98
 
99
# ✨ Generate an English sentence and translate it in one call.
@app.get("/generate")
def generate(prompt: str = "Daily conversation", target_lang: LanguageEnum = LanguageEnum.it):
    """Generate a random English sentence from *prompt* and translate it."""
    english = generate_random_sentence(prompt)
    return {
        "prompt": prompt,
        "english": english,
        "translated": translate_text(english, target_lang.value),
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import tempfile

import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline, MarianMTModel, MarianTokenizer, WhisperProcessor, WhisperForConditionalGeneration
 
9
 
10
app = FastAPI()

# CORS: let the browser frontend (any origin) call this backend.
cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **cors_settings)
20
+
21
+ # Supported languages
22
+ translation_models = {
 
 
 
23
  "fr": "Helsinki-NLP/opus-mt-en-fr",
24
  "es": "Helsinki-NLP/opus-mt-en-es",
25
  "de": "Helsinki-NLP/opus-mt-en-de",
 
28
  "ru": "Helsinki-NLP/opus-mt-en-ru",
29
  "zh": "Helsinki-NLP/opus-mt-en-zh",
30
  "ar": "Helsinki-NLP/opus-mt-en-ar",
31
+ "ta": "Helsinki-NLP/opus-mt-en-ta"
32
  }
33
 
34
# Heavy models are created once at import time and shared by all requests.
WHISPER_CHECKPOINT = "openai/whisper-base"

generator = pipeline("text-generation", model="gpt2")
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_CHECKPOINT)
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_CHECKPOINT)
38
+
39
@app.get("/")
def root():
    """Health-check endpoint confirming the API is reachable."""
    return {"message": "Backend is live βœ…"}
42
+
43
@app.get("/generate")
def generate_and_translate(prompt: str, target_lang: str):
    """Generate an English sentence from *prompt* with GPT-2, then translate it.

    Returns ``{"english": ..., "translated": ...}`` on success, an error dict
    for an unsupported language, or a 500 JSON response on any exception.
    """
    try:
        if target_lang not in translation_models:
            return {"error": "Unsupported language."}

        # 1. Generate an English sentence from the prompt.
        generated = generator(prompt, max_length=30, num_return_sequences=1)
        english_sentence = generated[0]["generated_text"].strip()

        # 2. Translate it with the MarianMT model for the requested language.
        checkpoint = translation_models[target_lang]
        mt_tokenizer = MarianTokenizer.from_pretrained(checkpoint)
        mt_model = MarianMTModel.from_pretrained(checkpoint)
        batch = mt_tokenizer(english_sentence, return_tensors="pt", padding=True)
        output_ids = mt_model.generate(**batch)
        translated_text = mt_tokenizer.decode(output_ids[0], skip_special_tokens=True)

        return {"english": english_sentence, "translated": translated_text}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
64
+
65
class TranslateRequest(BaseModel):
    """JSON request body for the /translate endpoint."""

    # English source text to translate.
    text: str
    # Target language code — must be a key of translation_models, e.g. "fr".
    target_lang: str
68
 
 
69
@app.post("/translate")
def translate_text(data: TranslateRequest):
    """Translate ``data.text`` into ``data.target_lang`` with MarianMT.

    Returns ``{"translated_text": ...}``, an error dict for an unsupported
    language, or a 500 JSON response on any exception.
    """
    try:
        if data.target_lang not in translation_models:
            return {"error": "Unsupported language."}

        checkpoint = translation_models[data.target_lang]
        mt_tokenizer = MarianTokenizer.from_pretrained(checkpoint)
        mt_model = MarianMTModel.from_pretrained(checkpoint)
        batch = mt_tokenizer(data.text, return_tensors="pt", padding=True)
        output_ids = mt_model.generate(**batch)
        translated = mt_tokenizer.decode(output_ids[0], skip_special_tokens=True)

        return {"translated_text": translated}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
85
+
86
@app.post("/transcribe")
async def transcribe_audio(audio: UploadFile = File(...)):
    """Transcribe an uploaded WAV file with the shared Whisper model.

    Returns ``{"transcribed_text": ...}`` on success or a 500 JSON response
    with the error message on failure.
    """
    temp_path = None
    try:
        # Persist the upload to disk so soundfile can read it.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_path = temp_file.name
        temp_file.write(await audio.read())
        temp_file.close()

        audio_data, _ = sf.read(temp_path)
        # NOTE(review): the audio is assumed to already be 16 kHz mono, which
        # is what Whisper's processor expects; no resampling is done here.
        inputs = whisper_processor(audio_data, sampling_rate=16000, return_tensors="pt")
        predicted_ids = whisper_model.generate(inputs["input_features"])
        transcription = whisper_processor.decode(predicted_ids[0], skip_special_tokens=True)

        return {"transcribed_text": transcription}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Bug fix: the delete=False temp file was never removed before,
        # leaking one file per request.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
101
 
 
102
@app.post("/process")
async def transcribe_and_translate_audio(
    audio: UploadFile = File(...),
    target_lang: str = Form(...)
):
    """Transcribe an uploaded WAV file, then translate the transcript.

    Returns both the English transcription and its translation into
    *target_lang*, an error dict for an unsupported language, or a 500 JSON
    response on any exception.
    """
    temp_path = None
    try:
        if target_lang not in translation_models:
            return {"error": "Unsupported language."}

        # Save uploaded file so soundfile can read it from disk.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_path = temp_file.name
        temp_file.write(await audio.read())
        temp_file.close()

        # Transcribe with the shared Whisper model.
        audio_data, _ = sf.read(temp_path)
        # NOTE(review): input is assumed to already be 16 kHz; no resampling.
        inputs = whisper_processor(audio_data, sampling_rate=16000, return_tensors="pt")
        predicted_ids = whisper_model.generate(inputs["input_features"])
        transcription = whisper_processor.decode(predicted_ids[0], skip_special_tokens=True)

        # Translate the transcription with MarianMT.
        model_name = translation_models[target_lang]
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        tokens = tokenizer(transcription, return_tensors="pt", padding=True)
        translated_ids = model.generate(**tokens)
        translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)

        return {
            "transcribed_text": transcription,
            "translated_text": translated_text
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Bug fix: remove the delete=False temp file that previously leaked
        # one file per request.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)