from fastapi import FastAPI, File, UploadFile, Form, HTTPException
import uvicorn
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import ffmpeg
import io
from io import BytesIO
import logging
from flores200_codes import flores_codes
import nltk
import librosa
import json
import soundfile as sf
import numpy as np
import base64
from PIL import Image
import spaces
import easyocr
import cv2

nltk.download("punkt")
nltk.download("punkt_tab")

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Base Whisper ASR model from Hugging Face (per-language models are loaded from asr_models_config.json)
model_name = "openai/whisper-base"
device = 0 if torch.cuda.is_available() else -1  # use GPU if available
# Set up the NLLB translation pipeline
def get_translation_pipeline(translation_model_path):
    model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_path)
    tokenizer = AutoTokenizer.from_pretrained(translation_model_path)
    translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer, device=device)
    return translation_pipeline

translator = get_translation_pipeline("mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4")
asr_config_settings = {}
asr_pipelines = {}
asr_preload_languages = ["eng"]

def load_asr_model(model_id):
    model_pipeline = pipeline("automatic-speech-recognition", model=model_id, device=device)
    return model_pipeline

def initialize_asr_pipelines(load_models=False):
    global asr_config_settings
    global asr_pipelines
    with open("asr_models_config.json") as f:
        asr_config_settings = json.loads(f.read())
    # Iterate through the config's language entries and load a model for each into a dictionary
    for lang, lang_config in asr_config_settings.items():
        if lang in asr_preload_languages or load_models:
            asr_pipelines[lang] = load_asr_model(lang_config["model_repo"])
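# asr_models_config.json is assumed to map language codes to per-language model settings with at
# least a "model_repo" key, e.g. (illustrative only; actual repository ids and extra keys may differ):
#
#   {
#     "eng": { "model_repo": "openai/whisper-base" },
#     "kik": { "model_repo": "<per-language ASR model repo>" }
#   }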
def ensure_asr_pipeline_loaded(lang_code):
    global asr_config_settings
    global asr_pipelines
    # Lazily load the ASR model for this language if it is not already cached
    if lang_code not in asr_pipelines:
        lang_config = asr_config_settings[lang_code]
        asr_pipelines[lang_code] = load_asr_model(lang_config["model_repo"])
class RecognitionResponse(BaseModel):
    text: str

@app.post("/recognize")  # route path assumed; adjust to match the client
async def recognize_audio(audio: UploadFile = File(...), language: str = Form("en")):
    try:
        # Read the uploaded audio data
        audio_bytes = await audio.read()
        # Convert the audio bytes to WAV format if needed
        try:
            input_audio = ffmpeg.input("pipe:0")
            audio_data, _ = (
                input_audio.output("pipe:1", format="wav")
                .run(input=audio_bytes, capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error as e:
            logger.error("FFmpeg error while converting audio data", exc_info=True)
            raise HTTPException(status_code=400, detail="Invalid audio format")
        # Run the ASR model for the requested language on the audio
        language = language.strip('"')
        ensure_asr_pipeline_loaded(language)
        transcriber = asr_pipelines[language]
        result = transcriber(audio_data, return_timestamps="word")
        # Extract the transcription text and word-level timestamps
        transcription = result["text"].capitalize()
        segments = result["chunks"]
        transcription_result = []
        for segment in segments:
            transcription_result.append({
                "word": segment["text"],
                "startTime": round(segment["timestamp"][0], 1),
                "endTime": round(segment["timestamp"][1], 1)
            })
        return {"text": transcription, "chunks": transcription_result}
    except HTTPException:
        # Re-raise client errors (e.g. invalid audio format) instead of converting them to 500s
        raise
    except Exception as e:
        logger.error("Unexpected error during transcription", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error")
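# Illustrative response shape for the endpoint above (values are made up):
#
#   {
#     "text": "Hello world",
#     "chunks": [
#       {"word": "Hello", "startTime": 0.0, "endTime": 0.4},
#       {"word": "world", "startTime": 0.5, "endTime": 0.9}
#     ]
#   }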
class TranslationRequest(BaseModel):
    text: str
    sourceLanguage: str
    targetLanguage: str

class TranslationResponse(BaseModel):
    translatedText: str

@app.post("/translate")  # route path assumed; adjust to match the client
async def translate_text(request: TranslationRequest):
    source_language = request.sourceLanguage
    target_language = request.targetLanguage
    text_to_translate = request.text
    try:
        # Map the request language codes to the FLORES-200 codes expected by the NLLB model
        src_lang = flores_codes[source_language]
        tgt_lang = flores_codes[target_language]
        translated_text = translator(text_to_translate, src_lang=src_lang, tgt_lang=tgt_lang)[0]["translation_text"]
        return TranslationResponse(translatedText=translated_text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
class TTSRequest(BaseModel):
    language: str
    text: str

class TTSResponse(BaseModel):
    audioBytes: str
    sampleRate: int

tts_config_settings = {}
tts_pipelines = {}
tts_preload_languages = ["kik"]

def load_tts_model(model_id):
    model_pipeline = pipeline("text-to-speech", model=model_id, device=device)
    return model_pipeline

def initialize_tts_pipelines(load_models=False):
    global tts_config_settings
    global tts_pipelines
    with open("tts_models_config.json") as f:
        tts_config_settings = json.loads(f.read())
    for lang, lang_config in tts_config_settings.items():
        if lang in tts_preload_languages or load_models:
            tts_pipelines[lang] = load_tts_model(lang_config["model_repo"])
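# tts_models_config.json is assumed to follow the same shape as asr_models_config.json above:
# a mapping from language code to an object with at least a "model_repo" entry.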
def ensure_tts_pipeline_loaded(lang_code):
    global tts_config_settings
    global tts_pipelines
    # Lazily load the TTS model for this language if it is not already cached
    if lang_code not in tts_pipelines:
        lang_config = tts_config_settings[lang_code]
        tts_pipelines[lang_code] = load_tts_model(lang_config["model_repo"])
@app.post("/tts")  # route path assumed; adjust to match the client
async def text_to_speech(request: TTSRequest):
    """
    Convert the given text to speech and return the audio data in Base64 format.
    """
    print(request)
    text = request.text.strip()
    language = request.language.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text is empty")
    try:
        # Generate speech using the TTS pipeline for the requested language
        print("Generating speech...")
        ensure_tts_pipeline_loaded(language)
        tts_pipeline = tts_pipelines[language]
        result = tts_pipeline(text)
        audio_tensor = result["audio"]
        sample_rate = result.get("sampling_rate", 22050)
        # Transpose to the (samples, channels) layout expected by soundfile
        audio_16bit = audio_tensor.T
        # Save the audio to a BytesIO buffer as a 16-bit PCM WAV file
        buffer = io.BytesIO()
        sf.write(buffer, audio_16bit, sample_rate, format="WAV", subtype="PCM_16")
        buffer.seek(0)
        # Encode the audio as Base64
        audio_bytes = base64.b64encode(buffer.read()).decode("utf-8")
        return TTSResponse(audioBytes=audio_bytes, sampleRate=sample_rate)
    except Exception as e:
        logger.error("Unexpected error during TTS", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
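# For reference, a client can reconstruct the WAV file from the response above.
# A minimal sketch, assuming the Base64 payload is returned in the "audioBytes" field:
#
#   import base64
#   audio = base64.b64decode(response_json["audioBytes"])
#   with open("speech.wav", "wb") as f:
#       f.write(audio)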
# TrOCR model for printed-text OCR
ocr_model_name = "microsoft/trocr-large-printed"
ocr_processor = TrOCRProcessor.from_pretrained(ocr_model_name)
ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name)
# Ensure we're using the appropriate device (GPU if available)
ocr_device = "cuda" if torch.cuda.is_available() else "cpu"
ocr_model.to(ocr_device)
# EasyOCR reader used by process_ocr below (the language list here is an assumption)
reader = easyocr.Reader(["en"], gpu=torch.cuda.is_available())

class OcrRequest(BaseModel):
    imageBase64: str  # Base64-encoded image

class TextBlock(BaseModel):
    text: str
    confidence: float
    bbox: list  # bounding-box corner points as [[x, y], ...]

class OcrResponse(BaseModel):
    text: str

def base64_to_image(base64_str: str) -> Image.Image:
    """Convert a base64 string to a PIL Image."""
    try:
        image_data = base64.b64decode(base64_str)
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise ValueError("Invalid image data")
# Alternative OCR implementation using TrOCR
async def process_ocr2(image: UploadFile = File(...)):
    try:
        # Read the uploaded file
        image_bytes = await image.read()
        # Convert bytes to PIL Image
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(ocr_device)
        # Perform OCR using the model
        generated_ids = ocr_model.generate(pixel_values)
        text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        print("Extracted text: " + text)
        return OcrResponse(text=text.strip())
    except ValueError as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred during OCR processing")
@app.post("/ocr")  # route path assumed; adjust to match the client
async def process_ocr(image: UploadFile = File(...)):
    try:
        # Validate the uploaded file
        if not image.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")
        # Read and decode the image file
        contents = await image.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            raise HTTPException(status_code=400, detail="Invalid image file")
        # Perform OCR with EasyOCR
        results = reader.readtext(img)
        # Process results
        text_blocks = []
        full_text = []
        for bbox, text, confidence in results:
            # Convert bbox coordinates to integers
            bbox_int = [[int(coord) for coord in point] for point in bbox]
            text_blocks.append(
                TextBlock(
                    text=text,
                    confidence=float(confidence),
                    bbox=bbox_int
                )
            )
            full_text.append(text)
        text = " ".join(full_text)
        print("Extracted text: " + text)
        return OcrResponse(text=text.strip())
    except HTTPException:
        # Re-raise client errors instead of converting them to 500s
        raise
    except ValueError as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred during OCR processing")
# Optional: health check endpoint
@app.get("/health")  # route path assumed
async def health_check():
    return {"status": "healthy"}
# Preload models and run the FastAPI application
initialize_tts_pipelines(True)
initialize_asr_pipelines()

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)