import base64
import io
import json
import logging
import os
from io import BytesIO
from typing import List

import cv2
import easyocr
import ffmpeg
import librosa
import nltk
import numpy as np
import soundfile as sf
import spaces
import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    pipeline,
)

from flores200_codes import flores_codes
# Ensure EasyOCR uses a directory with write permissions
# os.environ["EASYOCR_CACHE_DIR"] = "/app/.EasyOCR"
nltk.download("punkt")
nltk.download('punkt_tab')
app = FastAPI()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# EasyOCR reader, initialized once at startup (see startup_event below)
# and shared across requests for better performance.
reader = None
@app.on_event("startup")
async def startup_event():
"""Initialize EasyOCR during startup"""
global reader
try:
logger.info("Checking GPU availability...")
if torch.cuda.is_available():
device = "cuda"
gpu = True
logger.info(f"GPU detected: {torch.cuda.get_device_name(0)}")
else:
device = "cpu"
gpu = False
logger.warning("No GPU detected, falling back to CPU")
logger.info("Initializing EasyOCR and downloading models...")
# Set download directory to ensure we know where models are stored
model_storage_directory = os.path.join(os.getcwd(), "models")
# model_storage_directory = "/app/.EasyOCR"
logger.info("Creating models folder")
os.makedirs(model_storage_directory, exist_ok=True)
logger.info("Created temporary folder")
# Download and initialize model
        reader = easyocr.Reader(
            ['en'],
            gpu=gpu,  # enable GPU acceleration when available
        )
logger.info(f"Initialized the reader. Testing operation using sample image")
# Perform a small inference to ensure everything is loaded
sample_image = np.zeros((100, 100), dtype=np.uint8)
reader.readtext(sample_image, detail=0)
logger.info(f"EasyOCR initialization completed successfully using {device.upper()}")
except Exception as e:
logger.error(f"Failed to initialize EasyOCR: {str(e)}")
raise e
# Default ASR model name (per-language ASR models are loaded from asr_models_config.json below)
model_name = "openai/whisper-base"
device = 0 if torch.cuda.is_available() else -1 # Use GPU if available
# set up translation pipeline
def get_translation_pipeline(translation_model_path):
model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_path)
tokenizer = AutoTokenizer.from_pretrained(translation_model_path)
translation_pipeline = pipeline('translation', model=model, tokenizer=tokenizer, device=device)
return translation_pipeline
translator = get_translation_pipeline("mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4")
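# Usage sketch for the translator pipeline (the FLORES-200 codes below are
# illustrative; the real codes are looked up in flores_codes at request time):
#   translator("Habari ya asubuhi", src_lang="swh_Latn", tgt_lang="eng_Latn")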
asr_config_settings = {}
asr_pipelines = {}
asr_preload_languages = ["eng"]
def load_asr_model(model_id):
model_pipeline = pipeline("automatic-speech-recognition", model=model_id, device=device)
return model_pipeline
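# Assumed shape of asr_models_config.json (only "model_repo" is read here;
# the language keys and repo names below are illustrative):
# {
#   "eng": {"model_repo": "openai/whisper-base"},
#   "swh": {"model_repo": "..."}
# }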
def initialize_asr_pipelines(load_models=False):
global asr_config_settings
global asr_pipelines
with open(f"asr_models_config.json") as f:
asr_config_settings = json.loads(f.read())
# iterate through config languge entries and load model for each into a dictionary
for lang, lang_config in asr_config_settings.items():
if lang in asr_preload_languages or load_models:
asr_pipelines[lang] = load_asr_model(lang_config["model_repo"])
def ensure_asr_pipeline_loaded(lang_code):
    global asr_config_settings
    global asr_pipelines
    # Lazily load the pipeline the first time a language is requested
    if lang_code not in asr_pipelines:
        lang_config = asr_config_settings[lang_code]
        asr_pipelines[lang_code] = load_asr_model(lang_config["model_repo"])
class RecognitionResponse(BaseModel):
text: str
@spaces.GPU
@app.post("/recognize")
async def recognize_audio(audio: UploadFile = File(...), language: str = Form("eng")):  # default matches the "eng" key used in asr_preload_languages
try:
# Read audio data
audio_bytes = await audio.read()
# Convert audio bytes to WAV format if needed
try:
input_audio = ffmpeg.input('pipe:0')
audio_data, _ = (
input_audio.output('pipe:1', format='wav')
.run(input=audio_bytes, capture_stdout=True, capture_stderr=True)
)
except ffmpeg.Error as e:
logger.error("FFmpeg error while converting audio data", exc_info=True)
raise HTTPException(status_code=400, detail="Invalid audio format")
# Run Whisper model on the audio
language = language.strip('\"')
ensure_asr_pipeline_loaded(language)
transcriber = asr_pipelines[language]
result = transcriber(audio_data, return_timestamps="word")
# Extract transcription text
transcription = result["text"].capitalize()
segments = result["chunks"]
        transcription_result = []
        for segment in segments:
            transcription_result.append({
                "word": segment['text'],
                "startTime": round(segment['timestamp'][0], 1),
                "endTime": round(segment['timestamp'][1], 1)
            })
        return {"text": transcription, "chunks": transcription_result}
except Exception as e:
logger.error("Unexpected error during transcription", exc_info=True)
raise HTTPException(status_code=500, detail="Internal Server Error")
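# Example /recognize response (shape follows the return statement above;
# the values are illustrative):
# {
#   "text": "Hello world",
#   "chunks": [{"word": "Hello", "startTime": 0.0, "endTime": 0.4}, ...]
# }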
class TranslationRequest(BaseModel):
text: str
sourceLanguage: str
targetLanguage: str
class TranslationResponse(BaseModel):
translatedText: str
@spaces.GPU
@app.post("/translate", response_model=TranslationResponse)
async def translate_text(request: TranslationRequest):
source_language = request.sourceLanguage
target_language = request.targetLanguage
text_to_translate = request.text
    # Unknown language names are a client error, not a server error
    if source_language not in flores_codes or target_language not in flores_codes:
        raise HTTPException(status_code=400, detail="Unsupported language")
    src_lang = flores_codes[source_language]
    tgt_lang = flores_codes[target_language]
    try:
        translated_text = translator(text_to_translate, src_lang=src_lang, tgt_lang=tgt_lang)[0]['translation_text']
        return TranslationResponse(translatedText=translated_text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
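# Example /translate request (host is hypothetical; the port matches the
# uvicorn.run call below, and the language names must be keys in flores_codes):
#   curl -X POST http://localhost:7860/translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Good morning", "sourceLanguage": "English", "targetLanguage": "Swahili"}'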
class TTSRequest(BaseModel):
language: str
text: str
class TTSResponse(BaseModel):
audioBytes: str
sampleRate: int
tts_config_settings = {}
tts_pipelines = {}
tts_preload_languages = ["kik"]
def load_tts_model(model_id):
model_pipeline = pipeline("text-to-speech", model=model_id, device=device)
return model_pipeline
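# tts_models_config.json is assumed to follow the same
# {"lang": {"model_repo": ...}} shape as asr_models_config.json above.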
def initialize_tts_pipelines(load_models=False):
global tts_config_settings
global tts_pipelines
with open(f"tts_models_config.json") as f:
tts_config_settings = json.loads(f.read())
for lang, lang_config in tts_config_settings.items():
if lang in tts_preload_languages or load_models:
tts_pipelines[lang] = load_tts_model(lang_config["model_repo"])
def ensure_tts_pipeline_loaded(lang_code):
    global tts_config_settings
    global tts_pipelines
    # Lazily load the pipeline the first time a language is requested
    if lang_code not in tts_pipelines:
        lang_config = tts_config_settings[lang_code]
        tts_pipelines[lang_code] = load_tts_model(lang_config["model_repo"])
@spaces.GPU
@app.post("/text-to-speech", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest):
"""
Convert the given text to speech and return the audio data in Base64 format.
"""
    logger.info(f"TTS request: {request}")
text = request.text.strip()
language = request.language.strip()
if not text:
raise HTTPException(status_code=400, detail="Input text is empty")
try:
# Generate speech using the TTS pipeline
print("Generating speech...")
ensure_tts_pipeline_loaded(language)
tts_pipeline = tts_pipelines[language]
        result = tts_pipeline(text)
        audio_array = result["audio"]
        sample_rate = result.get("sampling_rate", 22050)
        # Transpose (channels, samples) -> (samples, channels) for soundfile
        audio_array = audio_array.T
        # Save the audio to a BytesIO buffer as a 16-bit PCM WAV file
        buffer = io.BytesIO()
        sf.write(buffer, audio_array, sample_rate, format="WAV", subtype="PCM_16")
buffer.seek(0)
# Encode the audio as Base64
audio_bytes = base64.b64encode(buffer.read()).decode("utf-8")
return TTSResponse(audioBytes=audio_bytes, sampleRate=sample_rate)
except Exception as e:
logger.error("Unexpected error during TTS ", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
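# Client-side decoding sketch for the Base64 WAV payload (hypothetical client
# code; `resp` stands for the HTTP response from /text-to-speech):
#   import base64, io
#   import soundfile as sf
#   wav_bytes = base64.b64decode(resp.json()["audioBytes"])
#   audio, sr = sf.read(io.BytesIO(wav_bytes))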
"""
ocr_model_name = "microsoft/trocr-large-printed"
ocr_processor = TrOCRProcessor.from_pretrained(ocr_model_name)
ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name)
# Ensure we're using the appropriate device (GPU if available)
ocr_device = "cuda" if torch.cuda.is_available() else "cpu"
ocr_model.to(ocr_device)
"""
class OcrRequest(BaseModel):
imageBase64: str # Base64-encoded image
class OcrResponse(BaseModel):
text: str
def base64_to_image(base64_str: str) -> Image.Image:
"""Convert a base64 string to a PIL Image."""
try:
image_data = base64.b64decode(base64_str)
return Image.open(BytesIO(image_data)).convert("RGB")
except Exception as e:
raise ValueError("Invalid image data")
@spaces.GPU
@app.post("/ocr2", response_model=OcrResponse)
async def process_ocr2(image: UploadFile = File(...)):
try:
        # Read the uploaded file
        image_bytes = await image.read()
        # Convert bytes to a PIL Image (avoid shadowing the UploadFile parameter)
        pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")
        pixel_values = ocr_processor(images=pil_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(ocr_device)
        # Perform OCR using the model
        generated_ids = ocr_model.generate(pixel_values)
        text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        logger.info("Extracted text: " + text)
return OcrResponse(text=text.strip())
except ValueError as e:
logger.error("Unexpected error during OCR ", exc_info=True)
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error("Unexpected error during OCR ", exc_info=True)
raise HTTPException(status_code=500, detail="An error occurred during OCR processing")
@spaces.GPU
@app.post("/ocr", response_model=OcrResponse)
async def process_ocr(image: UploadFile = File(...)):
if reader is None:
raise HTTPException(status_code=500, detail="OCR system not initialized")
try:
# Validate image file
if not image.content_type.startswith('image/'):
raise HTTPException(status_code=400, detail="File must be an image")
# Read image file
contents = await image.read()
nparr = np.frombuffer(contents, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
raise HTTPException(status_code=400, detail="Invalid image file")
        # Perform OCR; detail=0 returns a list of recognized strings
        results = reader.readtext(img, detail=0)
        text = " ".join(results)
        logger.info("Extracted text: " + text)
        return OcrResponse(text=text.strip())
except ValueError as e:
logger.error("Unexpected error during OCR ", exc_info=True)
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error("Unexpected error during OCR ", exc_info=True)
raise HTTPException(status_code=500, detail="An error occurred during OCR processing")
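# Example /ocr request (host is hypothetical; the multipart field name must
# be "image" to match the endpoint signature above):
#   curl -X POST http://localhost:7860/ocr -F "image=@sample.png"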
# Optional: Add a health check endpoint
@app.get("/health")
async def health_check():
return {"status": "healthy"}
# Preload the configured TTS and ASR pipelines at import time
initialize_tts_pipelines(True)
initialize_asr_pipelines()

# Run the FastAPI application
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)