Aswinthmani committed on
Commit 00c2961 · verified · 1 Parent(s): 11ff280

Update main.py

Files changed (1)
main.py +42 -49
main.py CHANGED
@@ -1,33 +1,43 @@
 from fastapi import FastAPI, File, UploadFile, Form
 from fastapi.responses import JSONResponse
+from fastapi.middleware.cors import CORSMiddleware
 from enum import Enum
 from transformers import pipeline, MarianMTModel, MarianTokenizer
 import shutil
 import os
 import uuid
-import uvicorn
-from googletrans import Translator

-os.environ["HF_HOME"] = "/app/.cache/huggingface"
+# Set Hugging Face cache directory (essential for Hugging Face Spaces)
+os.environ["HF_HOME"] = "/app/.cache/huggingface"
+
 app = FastAPI()

-# 🎯 Hugging Face Pipelines
-asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-medium")
-generator_pipeline = pipeline("text-generation", model="distilgpt2")
+# CORS for frontend
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# ✅ Use smaller model to avoid timeout
+asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
+generator_pipeline = pipeline("text-generation", model="sshleifer/tiny-gpt2")

-# 🌍 Language Enum for dropdown in Swagger
+# Supported languages (dropdown in Swagger UI)
 class LanguageEnum(str, Enum):
-    ta = "ta"  # Tamil
-    fr = "fr"  # French
-    es = "es"  # Spanish
-    de = "de"  # German
-    it = "it"  # Italian
-    hi = "hi"  # Hindi
-    ru = "ru"  # Russian
-    zh = "zh"  # Chinese
-    ar = "ar"  # Arabic
+    ta = "ta"
+    fr = "fr"
+    es = "es"
+    de = "de"
+    it = "it"
+    hi = "hi"
+    ru = "ru"
+    zh = "zh"
+    ar = "ar"

-# 🌐 Map target language to translation model
+# Language model mapping
 model_map = {
     "fr": "Helsinki-NLP/opus-mt-en-fr",
     "es": "Helsinki-NLP/opus-mt-en-es",
@@ -37,23 +47,12 @@ model_map = {
     "ru": "Helsinki-NLP/opus-mt-en-ru",
     "zh": "Helsinki-NLP/opus-mt-en-zh",
     "ar": "Helsinki-NLP/opus-mt-en-ar",
-    "ta": "gsarti/opus-mt-en-ta"
+    "ta": "Helsinki-NLP/opus-mt-en-ta"  # Changed from gsarti to Helsinki version
 }

 def translate_text(text, target_lang):
-    if target_lang == "ta":
-        # Use Google Translate for Tamil
-        try:
-            translator = Translator()
-            result = translator.translate(text, dest="ta")
-            return result.text
-        except Exception as e:
-            return f"Google Translate failed: {str(e)}"
-
-    # Use MarianMT for other supported languages
     if target_lang not in model_map:
         return f"No model for language: {target_lang}"
-
     model_name = model_map[target_lang]
     tokenizer = MarianTokenizer.from_pretrained(model_name)
     model = MarianMTModel.from_pretrained(model_name)
@@ -61,38 +60,29 @@ def translate_text(text, target_lang):
     translated = model.generate(**encoded)
     return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]

-
-# 🧠 Generate a random English sentence
-def generate_random_sentence(prompt="Daily conversation", max_length=30):
-    result = generator_pipeline(prompt, max_length=max_length, num_return_sequences=1)
-    return result[0]["generated_text"].strip()
-
-# 🎤 Transcription endpoint
 @app.post("/transcribe")
 async def transcribe(audio: UploadFile = File(...)):
-    temp_filename = f"temp_{uuid.uuid4().hex}.wav"
-    with open(temp_filename, "wb") as f:
+    temp_file = f"temp_{uuid.uuid4().hex}.wav"
+    with open(temp_file, "wb") as f:
         shutil.copyfileobj(audio.file, f)
     try:
-        result = asr_pipeline(temp_filename)
+        result = asr_pipeline(temp_file)
         return JSONResponse(content={"transcribed_text": result["text"]})
     finally:
-        os.remove(temp_filename)
+        os.remove(temp_file)

-# 🌍 Translation endpoint
 @app.post("/translate")
 async def translate(text: str = Form(...), target_lang: LanguageEnum = Form(...)):
     translated = translate_text(text, target_lang.value)
     return JSONResponse(content={"translated_text": translated})

-# 🔁 Combined endpoint (speech-to-translation)
 @app.post("/process")
 async def process(audio: UploadFile = File(...), target_lang: LanguageEnum = Form(...)):
-    temp_filename = f"temp_{uuid.uuid4().hex}.wav"
-    with open(temp_filename, "wb") as f:
+    temp_file = f"temp_{uuid.uuid4().hex}.wav"
+    with open(temp_file, "wb") as f:
         shutil.copyfileobj(audio.file, f)
     try:
-        result = asr_pipeline(temp_filename)
+        result = asr_pipeline(temp_file)
         transcribed_text = result["text"]
         translated_text = translate_text(transcribed_text, target_lang.value)
         return JSONResponse(content={
@@ -100,15 +90,18 @@ async def process(audio: UploadFile = File(...), target_lang: LanguageEnum = For
             "translated_text": translated_text
         })
     finally:
-        os.remove(temp_filename)
+        os.remove(temp_file)

-# ✨ Generate + Translate endpoint
 @app.get("/generate")
-def generate(prompt: str = "Daily conversation", target_lang: LanguageEnum = LanguageEnum.it):
-    english = generate_random_sentence(prompt)
+def generate(prompt: str = "Daily conversation", target_lang: LanguageEnum = LanguageEnum.fr):
+    english = generator_pipeline(prompt, max_length=30, num_return_sequences=1)[0]["generated_text"].strip()
    translated = translate_text(english, target_lang.value)
     return {
         "prompt": prompt,
         "english": english,
         "translated": translated
     }
+
+@app.get("/")
+def root():
+    return {"message": "✅ Backend is live!"}