import base64
import io
import json
import logging
from io import BytesIO

import cv2
import easyocr
import ffmpeg
import librosa
import nltk
import numpy as np
import soundfile as sf
import spaces
import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from PIL import Image
from pydantic import BaseModel
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    pipeline,
)

from flores200_codes import flores_codes

nltk.download("punkt")
nltk.download("punkt_tab")
app = FastAPI()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default ASR model; the per-language ASR models are loaded from asr_models_config.json
model_name = "openai/whisper-base"
device = 0 if torch.cuda.is_available() else -1 # Use GPU if available
# set up translation pipeline
def get_translation_pipeline(translation_model_path):
    model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_path)
    tokenizer = AutoTokenizer.from_pretrained(translation_model_path)
    translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer, device=device)
    return translation_pipeline
translator = get_translation_pipeline("mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4")
asr_config_settings = {}
asr_pipelines={}
asr_preload_languages=["eng"]
def load_asr_model(model_id):
    model_pipeline = pipeline("automatic-speech-recognition", model=model_id, device=device)
    return model_pipeline


def initialize_asr_pipelines(load_models=False):
    global asr_config_settings
    global asr_pipelines

    with open("asr_models_config.json") as f:
        asr_config_settings = json.loads(f.read())

    # Iterate through the configured language entries and load a model for each into the dictionary
    for lang, lang_config in asr_config_settings.items():
        if lang in asr_preload_languages or load_models:
            asr_pipelines[lang] = load_asr_model(lang_config["model_repo"])


def ensure_asr_pipeline_loaded(lang_code):
    global asr_config_settings
    global asr_pipelines

    # Lazily load the ASR pipeline for this language if it is not cached yet
    if lang_code not in asr_pipelines:
        lang_config = asr_config_settings[lang_code]
        asr_pipelines[lang_code] = load_asr_model(lang_config["model_repo"])
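# For reference, the loaders above only assume that asr_models_config.json maps a
# language code to an object with a "model_repo" key. An illustrative (assumed) shape:
#
#   {
#     "eng": {"model_repo": "openai/whisper-base"},
#     "kik": {"model_repo": "<ASR repo for Kikuyu>"}
#   }
#
# The actual file in this repo may carry additional fields; only "model_repo" is read here.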
class RecognitionResponse(BaseModel):
    text: str
@spaces.GPU
@app.post("/recognize")
async def recognize_audio(audio: UploadFile = File(...), language: str = Form("en")):
    try:
        # Read audio data
        audio_bytes = await audio.read()

        # Convert audio bytes to WAV format if needed
        try:
            input_audio = ffmpeg.input('pipe:0')
            audio_data, _ = (
                input_audio.output('pipe:1', format='wav')
                .run(input=audio_bytes, capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error as e:
            logger.error("FFmpeg error while converting audio data", exc_info=True)
            raise HTTPException(status_code=400, detail="Invalid audio format")

        # Run the ASR model for the requested language on the audio
        language = language.strip('\"')
        ensure_asr_pipeline_loaded(language)
        transcriber = asr_pipelines[language]

        result = transcriber(audio_data, return_timestamps="word")

        # Extract the transcription text and word-level timestamps
        transcription = result["text"].capitalize()
        segments = result["chunks"]

        # return RecognitionResponse(text=transcription, chunks=curr_chunks)
        transcription_result = []
        for segment in segments:
            transcription_result.append({
                "word": segment['text'],
                "startTime": round(segment['timestamp'][0], 1),
                "endTime": round(segment['timestamp'][1], 1)
            })

        return {"text": transcription, "chunks": transcription_result}
        # return {"chunks": transcription_result}
    except HTTPException:
        # Let the 400 raised above reach the client with its own status code
        raise
    except Exception as e:
        logger.error("Unexpected error during transcription", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error")
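# Example call to /recognize (illustrative; the host and audio file are assumptions,
# "eng" is one of the preloaded language codes):
#
#   curl -X POST http://localhost:7860/recognize \
#        -F "audio=@sample.wav" -F "language=eng"
#
# The response carries the full transcription plus word-level timing, e.g.
#   {"text": "...", "chunks": [{"word": "...", "startTime": 0.0, "endTime": 0.4}, ...]}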
class TranslationRequest(BaseModel):
    text: str
    sourceLanguage: str
    targetLanguage: str


class TranslationResponse(BaseModel):
    translatedText: str
@spaces.GPU
@app.post("/translate", response_model=TranslationResponse)
async def translate_text(request: TranslationRequest):
    source_language = request.sourceLanguage
    target_language = request.targetLanguage
    text_to_translate = request.text

    try:
        # Map the requested languages to the FLORES-200 codes the NLLB model expects
        src_lang = flores_codes[source_language]
        tgt_lang = flores_codes[target_language]

        translated_text = translator(text_to_translate, src_lang=src_lang, tgt_lang=tgt_lang)[0]['translation_text']
        return TranslationResponse(translatedText=translated_text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
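# Example call to /translate (illustrative; the keys accepted for sourceLanguage and
# targetLanguage are whatever flores200_codes.flores_codes defines, mapping to FLORES-200
# codes such as "eng_Latn" or "swh_Latn"):
#
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Good morning", "sourceLanguage": "English", "targetLanguage": "Swahili"}'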
class TTSRequest(BaseModel):
    language: str
    text: str


class TTSResponse(BaseModel):
    audioBytes: str
    sampleRate: int
tts_config_settings = {}
tts_pipelines={}
tts_preload_languages=["kik"]
def load_tts_model(model_id):
    model_pipeline = pipeline("text-to-speech", model=model_id, device=device)
    return model_pipeline


def initialize_tts_pipelines(load_models=False):
    global tts_config_settings
    global tts_pipelines

    with open("tts_models_config.json") as f:
        tts_config_settings = json.loads(f.read())

    for lang, lang_config in tts_config_settings.items():
        if lang in tts_preload_languages or load_models:
            tts_pipelines[lang] = load_tts_model(lang_config["model_repo"])


def ensure_tts_pipeline_loaded(lang_code):
    global tts_config_settings
    global tts_pipelines

    # Lazily load the TTS pipeline for this language if it is not cached yet
    if lang_code not in tts_pipelines:
        lang_config = tts_config_settings[lang_code]
        tts_pipelines[lang_code] = load_tts_model(lang_config["model_repo"])
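# As with the ASR config, tts_models_config.json is only assumed to map a language code
# to an object with a "model_repo" key (illustrative shape; the real file may differ):
#
#   {
#     "kik": {"model_repo": "<TTS repo for Kikuyu>"}
#   }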
@spaces.GPU
@app.post("/text-to-speech", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest):
    """
    Convert the given text to speech and return the audio data in Base64 format.
    """
    logger.info(request)
    text = request.text.strip()
    language = request.language.strip()

    if not text:
        raise HTTPException(status_code=400, detail="Input text is empty")

    try:
        # Generate speech using the TTS pipeline
        logger.info("Generating speech...")
        ensure_tts_pipeline_loaded(language)
        tts_pipeline = tts_pipelines[language]

        # audio = tts_pipeline(text, return_tensors=True)["waveform"]
        result = tts_pipeline(text)
        audio_array = result["audio"]
        sample_rate = result.get("sampling_rate", 22050)

        # soundfile expects audio shaped (frames, channels), so transpose the pipeline output
        audio_array = audio_array.T

        # Save the audio to a BytesIO buffer as a 16-bit PCM WAV file
        buffer = io.BytesIO()
        sf.write(buffer, audio_array, sample_rate, format="WAV", subtype="PCM_16")
        buffer.seek(0)

        # Encode the audio as Base64
        audio_bytes = base64.b64encode(buffer.read()).decode("utf-8")
        return TTSResponse(audioBytes=audio_bytes, sampleRate=sample_rate)
    except Exception as e:
        logger.error("Unexpected error during TTS", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
ocr_model_name = "microsoft/trocr-large-printed"
ocr_processor = TrOCRProcessor.from_pretrained(ocr_model_name)
ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name)
# Ensure we're using the appropriate device (GPU if available)
ocr_device = "cuda" if torch.cuda.is_available() else "cpu"
ocr_model.to(ocr_device)
class OcrRequest(BaseModel):
    imageBase64: str  # Base64-encoded image


class OcrResponse(BaseModel):
    text: str


class TextBlock(BaseModel):
    # A single block of recognized text with its confidence and bounding box
    text: str
    confidence: float
    bbox: list
def base64_to_image(base64_str: str) -> Image.Image:
    """Convert a base64 string to a PIL Image."""
    try:
        image_data = base64.b64decode(base64_str)
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception:
        raise ValueError("Invalid image data")
@spaces.GPU
@app.post("/ocr2", response_model=OcrResponse)
async def process_ocr2(image: UploadFile = File(...)):
    try:
        # Read the uploaded file
        image_bytes = await image.read()

        # Convert bytes to PIL Image
        pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")

        pixel_values = ocr_processor(images=pil_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(ocr_device)

        # Perform OCR using the TrOCR model
        generated_ids = ocr_model.generate(pixel_values)
        text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        logger.info("Extracted text: " + text)
        return OcrResponse(text=text.strip())
    except ValueError as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred during OCR processing")
@spaces.GPU
@app.post("/ocr", response_model=OcrResponse)
async def process_ocr(image: UploadFile = File(...)):
    try:
        # Validate image file
        if not image.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read image file
        contents = await image.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if img is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        # Perform OCR with EasyOCR
        results = reader.readtext(img)

        # Process results
        text_blocks = []
        full_text = []
        for bbox, text, confidence in results:
            # Convert bbox coordinates to integers
            bbox_int = [[int(coord) for coord in point] for point in bbox]
            text_blocks.append(
                TextBlock(
                    text=text,
                    confidence=float(confidence),
                    bbox=bbox_int
                )
            )
            full_text.append(text)

        text = " ".join(full_text)
        logger.info("Extracted text: " + text)
        return OcrResponse(text=text.strip())
    except HTTPException:
        # Let the validation errors raised above keep their own status codes
        raise
    except ValueError as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error("Unexpected error during OCR", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred during OCR processing")
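# Example call to the EasyOCR-backed endpoint (illustrative; the file name is an assumption):
#
#   curl -X POST http://localhost:7860/ocr -F "image=@receipt.png"
#
# /ocr2 takes the same multipart form but runs the TrOCR model instead, which is geared
# toward single lines of printed text rather than full-page layouts.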
# Optional: Add a health check endpoint
@app.get("/health")
async def health_check():
return {"status": "healthy"}
# Preload the configured pipelines at import time, then start the FastAPI application
initialize_tts_pipelines(True)
initialize_asr_pipelines()

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)