Spaces:
Running
Running
import time | |
from PIL import Image | |
from fastapi import FastAPI, HTTPException, Query | |
from fastapi.middleware.trustedhost import TrustedHostMiddleware | |
from starlette.middleware.base import BaseHTTPMiddleware | |
from pydantic import BaseModel, Field | |
from transformers import MBartForConditionalGeneration, MBartTokenizerFast | |
import torch | |
import requests, json, base64 | |
from io import BytesIO | |
from typing import List, Optional, Tuple | |
import os | |
import warnings | |
import logging | |
# Suppress specific FutureWarning from huggingface_hub | |
warnings.filterwarnings( | |
"ignore", | |
category=FutureWarning, | |
module="huggingface_hub.file_download" | |
) | |
# Application object for the service. The title, version and description
# given here are the API's self-reported metadata.
app = FastAPI(
    description="An API for converting text to pictograms, supporting English, French, and Tamil.",
    version="1.0.0",
    title="Text-to-Pictogram API",
)
# Root logging configuration: INFO and above, each record rendered as
# "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Module-level logger shared by the middleware and the handlers in this file.
logger = logging.getLogger(__name__)
# Custom middleware for request time logging
class RequestTimingMiddleware(BaseHTTPMiddleware):
    """Log every request's method and URL plus the time taken to handle it."""

    async def dispatch(self, request, call_next):
        """Run the downstream handler and log how long it took.

        Args:
            request: The incoming request; its ``method`` and ``url`` are logged.
            call_next: Awaitable callable that produces the response.

        Returns:
            The response returned by ``call_next``, unmodified.
        """
        # perf_counter() is monotonic, so the measured duration cannot go
        # negative or jump if the system clock is adjusted mid-request.
        started = time.perf_counter()
        logger.info("Request received: %s %s", request.method, request.url)

        response = await call_next(request)

        elapsed = time.perf_counter() - started
        logger.info("Request processed: %s %s", request.method, request.url)
        logger.info("Processing time: %.4fs", elapsed)
        return response


# Add the middleware to the app
app.add_middleware(RequestTimingMiddleware)
# Named background colours mapped to (R, G, B) tuples. Names are lower-case.
# Add more colors as needed.
COLORS = dict(
    white=(255, 255, 255),
    black=(0, 0, 0),
    red=(255, 0, 0),
    green=(0, 255, 0),
    blue=(0, 0, 255),
    yellow=(255, 255, 0),
    cyan=(0, 255, 255),
    magenta=(255, 0, 255),
    gray=(128, 128, 128),
    orange=(255, 165, 0),
    purple=(128, 0, 128),
    brown=(165, 42, 42),
    pink=(255, 192, 203),
    lime=(0, 255, 0),
    teal=(0, 128, 128),
    navy=(0, 0, 128),
)
# Define schemas for requests and responses.
# TranslationRequest is the payload the translation handler receives:
# the text to convert and the language it is written in.
class TranslationRequest(BaseModel):
    # Required: text to be converted to pictograms.
    src: str = Field(
        default=...,
        description="Source text to be translated.",
    )
    # Required: language code of `src`.
    language: str = Field(
        default=...,
        description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.",
    )
# Payload returned by the translation handler. `language` and `src` echo the
# request; the remaining fields are optional and default to None.
class TranslationResponse(BaseModel):
    language: str = Field(
        default=...,
        description="Language of the source text.",
    )
    src: str = Field(
        default=...,
        description="Source text in the original language.",
    )
    # Decoded model output, when a translation was produced.
    tgt: Optional[str] = Field(
        default=None,
        description="Translated text in the original language.",
    )
    pictogram_ids: Optional[List[Optional[int]]] = Field(
        default=None,
        description="List of pictogram IDs corresponding to the translation.",
    )
    # PNG bytes of the concatenated pictograms, base64-encoded.
    image_base64: Optional[str] = Field(
        default=None,
        description="Base64-encoded image of the pictograms, if generated.",
    )
# Load the tokenizer and model once at import time and place the model on the
# GPU when one is available, otherwise on the CPU.
model_path = "feedlight42/mbart25-text2picto"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = MBartForConditionalGeneration.from_pretrained(model_path).to(device)
# Folder to store local pictogram images
pictogram_folder = 'pictogram_images'

# Load the pictogram dictionary from the JSON file.
# The encoding is stated explicitly: without it open() falls back to the
# platform's locale encoding, which can mis-decode non-ASCII keys in the
# vocabulary and make later word lookups silently miss.
with open('pictogram_vocab.json', 'r', encoding='utf-8') as f:
    pictogram_dict = json.load(f)
# Function to fetch a pictogram image from the ARASAAC API, or use local file if present
def fetch_pictogram(picto_id: int):
    """Return the image for ``picto_id``, or ``None`` if it cannot be obtained.

    A file named ``<picto_id>.png`` inside ``pictogram_folder`` is used when it
    exists. Otherwise the image is downloaded from the ARASAAC API and a copy
    is saved to that folder for later calls.

    Args:
        picto_id: Pictogram identifier used in the file name and the API URL.

    Returns:
        A PIL image, or ``None`` when the download fails, the API answers with
        a status other than 200, or the payload cannot be decoded as an image.
    """
    # Check if the pictogram image exists locally
    image_path = os.path.join(pictogram_folder, f"{picto_id}.png")
    if os.path.exists(image_path):
        return Image.open(image_path)

    # If the image is not local, request it from the ARASAAC API.
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    try:
        # Without a timeout a stalled connection would block the request
        # being served indefinitely.
        response = requests.get(url, timeout=10)
    except requests.RequestException as exc:
        logger.warning("Could not download pictogram %s: %s", picto_id, exc)
        return None
    if response.status_code != 200:
        return None

    try:
        img = Image.open(BytesIO(response.content))
    except OSError as exc:
        # Pillow signals undecodable data with an OSError subclass.
        logger.warning("Pictogram %s is not a readable image: %s", picto_id, exc)
        return None

    # Caching is best-effort: a missing folder is created, and a failed write
    # must not prevent the freshly downloaded image from being used.
    try:
        os.makedirs(pictogram_folder, exist_ok=True)
        img.save(image_path)
    except OSError as exc:
        logger.warning("Could not cache pictogram %s: %s", picto_id, exc)
    return img
# Generate an image from a sequence of pictogram IDs with a customizable background color
def create_pictogram_image(pictogram_ids, background_color=(255, 0, 0)):
    """Concatenate the pictograms for ``pictogram_ids`` left to right.

    Args:
        pictogram_ids: Iterable of pictogram IDs; IDs whose image cannot be
            fetched are skipped.
        background_color: ``(R, G, B)`` sequence used to fill the canvas and
            to back each pictogram. Defaults to red.

    Returns:
        An RGBA image as wide as the sum of the pictogram widths and as tall
        as the tallest pictogram, or ``None`` when no pictogram image could be
        obtained (including an empty ``pictogram_ids``).
    """
    # Accept lists as well as tuples, and append a fully opaque alpha value.
    fill = tuple(background_color) + (255,)

    pictogram_images = []
    for picto_id in pictogram_ids:
        picto_image = fetch_pictogram(picto_id)
        if picto_image is not None:
            pictogram_images.append(picto_image)

    # With nothing to draw, unpacking zip(*...) below would raise ValueError;
    # report "no image" instead so the caller can skip the picture.
    if not pictogram_images:
        return None

    # Canvas size: widths add up, height follows the tallest pictogram.
    widths, heights = zip(*(img.size for img in pictogram_images))
    total_width = sum(widths)
    max_height = max(heights)

    final_image = Image.new('RGBA', (total_width, max_height), fill)
    x_offset = 0
    for img in pictogram_images:
        img = img.convert("RGBA")  # Ensure the image has an alpha channel
        # Lay the pictogram over a tile of the background colour, using its
        # own alpha channel as the paste mask, then place the tile.
        img_with_bg = Image.new("RGBA", img.size, fill)
        img_with_bg.paste(img, (0, 0), img)
        final_image.paste(img_with_bg, (x_offset, 0), img_with_bg)
        x_offset += img.size[0]
    return final_image
# Endpoint for health check: always reports the service as healthy.
# NOTE(review): no route decorator is visible in SOURCE for this handler —
# confirm it is registered on `app`.
async def health_check():
    payload = dict(status="healthy", message="API is up and running")
    return payload
# Main translation endpoint
# NOTE(review): no route decorator is visible in SOURCE for this handler —
# confirm it is registered on `app`.
async def translate(
    request: TranslationRequest,
    backgroundColor: Optional[str] = Query("white"),
    backgroundColorRGB: Optional[Tuple[int, int, int]] = Query(None)
):
    """
    Translate the provided source text into pictograms and return a corresponding image.
    - **src**: The source text to be translated.
    - **language**: The source language. Accepted values are 'en', 'fr', 'ta'.
    - **backgroundColor**: (Optional) Background color for the pictogram image, specified by name (e.g., 'red'). Default is 'white'.
    - **backgroundColorRGB**: (Optional) Background color for the image in RGB format (e.g., (255, 0, 0) for red). This overrides backgroundColor.
    """
    # backgroundColor always carries a value because it defaults to "white".
    # Treating that default as "provided" rejected every request that sent
    # backgroundColorRGB, so only a non-default colour name counts as a
    # conflict. An explicit "white" is indistinguishable from the default and
    # is simply overridden by the RGB value.
    rgb_given = backgroundColorRGB is not None
    name_given = bool(backgroundColor) and backgroundColor.lower() != "white"
    if rgb_given and name_given:
        raise HTTPException(status_code=400, detail="You cannot provide both backgroundColor and backgroundColorRGB at the same time.")

    # Ensure that the given language is a valid one
    if request.language not in ["en", "fr", "ta"]:
        raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")

    # Resolve the background colour: RGB wins, then a known colour name;
    # unknown or missing names fall back to white.
    if rgb_given:
        background_color = tuple(backgroundColorRGB)
        if len(background_color) != 3 or any(not 0 <= channel <= 255 for channel in background_color):
            raise HTTPException(status_code=400, detail="backgroundColorRGB must be three integers between 0 and 255.")
    elif backgroundColor:
        background_color = COLORS.get(backgroundColor.lower(), (255, 255, 255))
    else:
        background_color = (255, 255, 255)

    # Temporary fix
    # Placeholder for unsupported languages: accepted, but nothing is translated.
    if request.language in ["en", "ta"]:
        return TranslationResponse(
            language=request.language,
            src=request.src,
            tgt=None,
            pictogram_ids=[],
            image_base64=None,
        )

    # Translate using French model
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each whitespace-separated word of the output to a pictogram ID,
    # dropping words with no (or a falsy) entry in the vocabulary.
    candidates = (pictogram_dict.get(word.lower()) for word in tgt_sentence.split())
    pictogram_ids = [picto_id for picto_id in candidates if picto_id]

    # Render the picture only when there is something to draw; the response
    # carries image_base64=None otherwise.
    encoded_image = None
    if pictogram_ids:
        final_image = create_pictogram_image(pictogram_ids, background_color)
        if final_image is not None:
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format="PNG")
            encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")

    return TranslationResponse(
        language=request.language,
        src=request.src,
        tgt=tgt_sentence,
        pictogram_ids=pictogram_ids,
        image_base64=encoded_image,
    )