from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import ffmpeg
import io
import logging
from flores200_codes import flores_codes
import nltk
import librosa
import json
import soundfile as sf
import numpy as np
import base64
from PIL import Image
from io import BytesIO
import spaces
import easyocr
import cv2
import os
from typing import List

# Ensure EasyOCR uses a directory with write permissions
# os.environ["EASYOCR_CACHE_DIR"] = "/app/.EasyOCR"

nltk.download("punkt")
nltk.download("punkt_tab")
app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global EasyOCR reader, initialized on startup (rather than at import time) for better performance
reader = None
@app.on_event("startup")
async def startup_event():
    """Initialize EasyOCR during startup."""
    global reader
    try:
        logger.info("Checking GPU availability...")
        if torch.cuda.is_available():
            device = "cuda"
            gpu = True
            logger.info(f"GPU detected: {torch.cuda.get_device_name(0)}")
        else:
            device = "cpu"
            gpu = False
            logger.warning("No GPU detected, falling back to CPU")

        logger.info("Initializing EasyOCR and downloading models...")

        # Set a download directory so we know where the models are stored
        model_storage_directory = os.path.join(os.getcwd(), "models")
        # model_storage_directory = "/app/.EasyOCR"
        logger.info("Creating models folder")
        os.makedirs(model_storage_directory, exist_ok=True)

        # Download and initialize the model
        reader = easyocr.Reader(
            ['en'],
            # model_storage_directory=model_storage_directory,
            # download_enabled=True,  # Force download even if the model exists
            gpu=gpu,  # Enable GPU if available
        )
        logger.info("Initialized the reader. Testing operation with a sample image...")

        # Perform a small inference to ensure everything is loaded
        sample_image = np.zeros((100, 100), dtype=np.uint8)
        reader.readtext(sample_image, detail=0)
        logger.info(f"EasyOCR initialization completed successfully using {device.upper()}")
    except Exception as e:
        logger.error(f"Failed to initialize EasyOCR: {str(e)}")
        raise
# Default Whisper model on Hugging Face (per-language ASR models are loaded from asr_models_config.json)
model_name = "openai/whisper-base"
device = 0 if torch.cuda.is_available() else -1  # Use GPU if available

# Set up the translation pipeline
def get_translation_pipeline(translation_model_path):
    model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_path)
    tokenizer = AutoTokenizer.from_pretrained(translation_model_path)
    translation_pipeline = pipeline('translation', model=model, tokenizer=tokenizer, device=device)
    return translation_pipeline

translator = get_translation_pipeline("mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4")
asr_config_settings = {}
asr_pipelines = {}
asr_preload_languages = ["eng"]

def load_asr_model(model_id):
    model_pipeline = pipeline("automatic-speech-recognition", model=model_id, device=device)
    return model_pipeline

def initialize_asr_pipelines(load_models=False):
    global asr_config_settings
    global asr_pipelines
    with open("asr_models_config.json") as f:
        asr_config_settings = json.loads(f.read())
    # Iterate through the config's language entries and load a model for each into the dictionary
    for lang, lang_config in asr_config_settings.items():
        if lang in asr_preload_languages or load_models:
            asr_pipelines[lang] = load_asr_model(lang_config["model_repo"])

def ensure_asr_pipeline_loaded(lang_code):
    global asr_config_settings
    global asr_pipelines
    if lang_code not in asr_pipelines:
        lang_config = asr_config_settings[lang_code]
        asr_pipelines[lang_code] = load_asr_model(lang_config["model_repo"])
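
# Illustrative shape of asr_models_config.json, inferred from the "model_repo"
# lookups above (the real file may carry additional fields):
# {
#     "eng": {"model_repo": "openai/whisper-base"},
#     "kik": {"model_repo": "<org>/<kikuyu-asr-model>"}
# }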
class RecognitionResponse(BaseModel):
    text: str

@app.post("/recognize")  # route path assumed; the original decorator was not preserved
async def recognize_audio(audio: UploadFile = File(...), language: str = Form("en")):
    try:
        # Read audio data
        audio_bytes = await audio.read()

        # Convert the audio bytes to WAV format if needed
        try:
            input_audio = ffmpeg.input('pipe:0')
            audio_data, _ = (
                input_audio.output('pipe:1', format='wav')
                .run(input=audio_bytes, capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error:
            logger.error("FFmpeg error while converting audio data", exc_info=True)
            raise HTTPException(status_code=400, detail="Invalid audio format")

        # Run the ASR model on the audio
        language = language.strip('"')
        ensure_asr_pipeline_loaded(language)
        transcriber = asr_pipelines[language]
        result = transcriber(audio_data, return_timestamps="word")

        # Extract the transcription text and word-level timestamps
        transcription = result["text"].capitalize()
        segments = result["chunks"]
        transcription_result = []
        for segment in segments:
            transcription_result.append({
                "word": segment['text'],
                "startTime": round(segment['timestamp'][0], 1),
                "endTime": round(segment['timestamp'][1], 1)
            })
        return {"text": transcription, "chunks": transcription_result}
    except HTTPException:
        # Re-raise client errors instead of masking them as 500s
        raise
    except Exception:
        logger.error("Unexpected error during transcription", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error")
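
# Example request (illustrative; the route path is an assumption, the port matches
# the uvicorn.run call at the bottom of this file):
#   curl -X POST http://localhost:7860/recognize \
#        -F "audio=@sample.wav" -F "language=eng"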
class TranslationRequest(BaseModel):
    text: str
    sourceLanguage: str
    targetLanguage: str

class TranslationResponse(BaseModel):
    translatedText: str

@app.post("/translate")  # route path assumed; the original decorator was not preserved
async def translate_text(request: TranslationRequest):
    source_language = request.sourceLanguage
    target_language = request.targetLanguage
    text_to_translate = request.text
    try:
        # Map the request's language names to FLORES-200 codes
        src_lang = flores_codes[source_language]
        tgt_lang = flores_codes[target_language]
        translated_text = translator(text_to_translate, src_lang=src_lang, tgt_lang=tgt_lang)[0]['translation_text']
        return TranslationResponse(translatedText=translated_text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
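
# Example request body (illustrative; both language names must be keys in flores_codes):
#   {"text": "Hello", "sourceLanguage": "English", "targetLanguage": "Swahili"}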
class TTSRequest(BaseModel):
    language: str
    text: str

class TTSResponse(BaseModel):
    audioBytes: str
    sampleRate: int

tts_config_settings = {}
tts_pipelines = {}
tts_preload_languages = ["kik"]

def load_tts_model(model_id):
    model_pipeline = pipeline("text-to-speech", model=model_id, device=device)
    return model_pipeline

def initialize_tts_pipelines(load_models=False):
    global tts_config_settings
    global tts_pipelines
    with open("tts_models_config.json") as f:
        tts_config_settings = json.loads(f.read())
    for lang, lang_config in tts_config_settings.items():
        if lang in tts_preload_languages or load_models:
            tts_pipelines[lang] = load_tts_model(lang_config["model_repo"])

def ensure_tts_pipeline_loaded(lang_code):
    global tts_config_settings
    global tts_pipelines
    if lang_code not in tts_pipelines:
        lang_config = tts_config_settings[lang_code]
        tts_pipelines[lang_code] = load_tts_model(lang_config["model_repo"])
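
# Illustrative shape of tts_models_config.json, inferred from the "model_repo"
# lookups above:
# {"kik": {"model_repo": "<org>/<kikuyu-tts-model>"}}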
@app.post("/tts")  # route path assumed; the original decorator was not preserved
async def text_to_speech(request: TTSRequest):
    """Convert the given text to speech and return the audio data in Base64 format."""
    logger.debug(request)
    text = request.text.strip()
    language = request.language.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text is empty")
    try:
        # Generate speech using the TTS pipeline
        logger.info("Generating speech...")
        ensure_tts_pipeline_loaded(language)
        tts_pipeline = tts_pipelines[language]
        result = tts_pipeline(text)
        audio_tensor = result["audio"]
        sample_rate = result.get("sampling_rate", 22050)

        # Transpose (channels, samples) to the (samples, channels) layout soundfile expects
        audio_16bit = audio_tensor.T

        # Save the audio to a BytesIO buffer as a 16-bit PCM WAV file
        buffer = io.BytesIO()
        sf.write(buffer, audio_16bit, sample_rate, format="WAV", subtype="PCM_16")
        buffer.seek(0)

        # Encode the audio as Base64
        audio_bytes = base64.b64encode(buffer.read()).decode("utf-8")
        return TTSResponse(audioBytes=audio_bytes, sampleRate=sample_rate)
    except Exception as e:
        logger.error("Unexpected error during TTS", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
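
# Example request body (illustrative; "kik" is preloaded via tts_preload_languages above):
#   {"language": "kik", "text": "Text to speak"}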
""" | |
ocr_model_name = "microsoft/trocr-large-printed" | |
ocr_processor = TrOCRProcessor.from_pretrained(ocr_model_name) | |
ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name) | |
# Ensure we're using the appropriate device (GPU if available) | |
ocr_device = "cuda" if torch.cuda.is_available() else "cpu" | |
ocr_model.to(ocr_device) | |
""" | |
class OcrRequest(BaseModel):
    imageBase64: str  # Base64-encoded image

class OcrResponse(BaseModel):
    text: str

def base64_to_image(base64_str: str) -> Image.Image:
    """Convert a base64 string to a PIL Image."""
    try:
        image_data = base64.b64decode(base64_str)
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception:
        raise ValueError("Invalid image data")
# Alternative TrOCR-based OCR handler (requires the disabled TrOCR block above to be enabled)
async def process_ocr2(image: UploadFile = File(...)):
    try:
        # Read the uploaded file
        image_bytes = await image.read()

        # Convert the bytes to a PIL Image
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(ocr_device)

        # Perform OCR using the model
        generated_ids = ocr_model.generate(pixel_values)
        text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        logger.info("Extracted text: " + text)
        return OcrResponse(text=text.strip())
    except ValueError as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred during OCR processing")
@app.post("/ocr")  # route path assumed; the original decorator was not preserved
async def process_ocr(image: UploadFile = File(...)):
    if reader is None:
        raise HTTPException(status_code=500, detail="OCR system not initialized")
    try:
        # Validate the uploaded file
        if not image.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and decode the image
        contents = await image.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        # Perform OCR; detail=0 returns a list of text strings, so join them
        results = reader.readtext(img, detail=0)
        text = " ".join(results)
        logger.info("Extracted text: " + text)
        return OcrResponse(text=text.strip())
    except HTTPException:
        # Re-raise client errors instead of masking them as 500s
        raise
    except ValueError as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred during OCR processing")
# Optional: a health check endpoint
@app.get("/health")  # route path assumed; the original decorator was not preserved
async def health_check():
    return {"status": "healthy"}

# Load the model pipelines, then run the FastAPI application
initialize_tts_pipelines(True)
initialize_asr_pipelines()

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)