import time
import os
import json
import base64
import logging
import warnings
from io import BytesIO
from typing import List, Optional, Tuple

import requests
import torch
from PIL import Image
from fastapi import FastAPI, HTTPException, Query
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel, Field
from transformers import MBartForConditionalGeneration, MBartTokenizerFast

# Suppress a specific FutureWarning emitted by huggingface_hub
warnings.filterwarnings(
    "ignore", category=FutureWarning, module="huggingface_hub.file_download"
)

# Initialize the FastAPI app with versioning
app = FastAPI(
    title="Text-to-Pictogram API",
    version="1.0.0",
    description="An API for converting text to pictograms, supporting English, French, and Tamil.",
)

# Set up a custom logging format
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


# Custom middleware that logs how long each request takes to process
class RequestTimingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        start_time = time.time()
        logger.info(f"Request received: {request.method} {request.url}")

        # Process the request
        response = await call_next(request)

        processing_time = time.time() - start_time

        # Log the timing
        logger.info(f"Request processed: {request.method} {request.url}")
        logger.info(f"Processing time: {processing_time:.4f}s")

        return response


# Register the middleware on the app
app.add_middleware(RequestTimingMiddleware)

# Named background colors accepted by the API, mapped to RGB tuples
COLORS = {
    "white": (255, 255, 255),
    "black": (0, 0, 0),
    "red": (255, 0, 0),
    "green": (0, 255, 0),
    "blue": (0, 0, 255),
    "yellow": (255, 255, 0),
    "cyan": (0, 255, 255),
    "magenta": (255, 0, 255),
    "gray": (128, 128, 128),
    "orange": (255, 165, 0),
    "purple": (128, 0, 128),
    "brown": (165, 42, 42),
    "pink": (255, 192, 203),
    "lime": (0, 255, 0),
    "teal": (0, 128, 128),
    "navy": (0, 0, 128),
    # Add more colors as needed
}


# Schemas for requests and responses
class TranslationRequest(BaseModel):
    src: str = Field(..., description="Source text to be translated.")
    language: str = Field(
        ..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'."
    )
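# Illustrative example of a request body matching TranslationRequest; the
# sentence is a made-up sample, not taken from the project's data:
# {
#   "src": "le chat mange la souris",
#   "language": "fr"
# }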
class TranslationResponse(BaseModel):
    language: str = Field(..., description="Language of the source text.")
    src: str = Field(..., description="Source text in the original language.")
    tgt: Optional[str] = Field(None, description="Translated text produced by the model.")
    pictogram_ids: Optional[List[Optional[int]]] = Field(
        None, description="List of pictogram IDs corresponding to the translation."
    )
    image_base64: Optional[str] = Field(
        None, description="Base64-encoded image of the pictograms, if generated."
    )


# Load the model and tokenizer
model_path = "feedlight42/mbart25-text2picto"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MBartForConditionalGeneration.from_pretrained(model_path)
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = model.to(device)

# Folder used to cache pictogram images locally
pictogram_folder = "pictogram_images"
os.makedirs(pictogram_folder, exist_ok=True)

# Load the word-to-pictogram dictionary from the JSON file
with open("pictogram_vocab.json", "r") as f:
    pictogram_dict = json.load(f)


# Fetch a pictogram image from the ARASAAC API, or use the local copy if present
def fetch_pictogram(picto_id: int):
    # Check whether the pictogram image already exists locally
    image_path = os.path.join(pictogram_folder, f"{picto_id}.png")
    if os.path.exists(image_path):
        return Image.open(image_path)

    # Otherwise, request it from the ARASAAC API and cache it locally
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    response = requests.get(url, timeout=10)
    if response.status_code == 200:
        img = Image.open(BytesIO(response.content))
        img.save(image_path)  # Cache the image in the local folder
        return img
    return None


# Generate an image from a sequence of pictogram IDs with a customizable
# background color (defaults to white, matching the endpoint's documentation)
def create_pictogram_image(pictogram_ids, background_color=(255, 255, 255)):
    pictogram_images = []
    for picto_id in pictogram_ids:
        picto_image = fetch_pictogram(picto_id)
        if picto_image:
            pictogram_images.append(picto_image)

    # Nothing could be fetched, so there is no image to build
    if not pictogram_images:
        return None

    # Compute the canvas size needed to place the pictograms side by side
    widths, heights = zip(*(i.size for i in pictogram_images))
    total_width = sum(widths)
    max_height = max(heights)

    # Create a new image with the specified background color (RGBA with alpha channel)
    final_image = Image.new("RGBA", (total_width, max_height), background_color + (255,))

    x_offset = 0
    for img in pictogram_images:
        img = img.convert("RGBA")  # Ensure the image has an alpha channel
        # Flatten the pictogram onto a solid background before pasting
        img_with_bg = Image.new("RGBA", img.size, background_color + (255,))
        img_with_bg.paste(img, (0, 0), img)
        final_image.paste(img_with_bg, (x_offset, 0), img_with_bg)
        x_offset += img.size[0]

    return final_image
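# Illustrative sketch of how the two helpers above compose; the IDs below are
# placeholders, not verified entries in pictogram_vocab.json:
# img = create_pictogram_image([2462, 7114], background_color=COLORS["blue"])
# if img:
#     img.save("example_output.png")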
# Health-check endpoints
@app.get("/")
@app.get("/health", summary="Health Check", response_description="Health status")
async def health_check():
    return {"status": "healthy", "message": "API is up and running"}


# Main translation endpoint
@app.post(
    "/v1/translate",
    summary="Translate Text to Pictograms",
    description="Translates text from a source language to a target language and converts the translation into pictograms. Optionally customize the background color of the generated pictogram images.",
)
async def translate(
    request: TranslationRequest,
    backgroundColor: Optional[str] = Query(None),
    backgroundColorRGB: Optional[Tuple[int, int, int]] = Query(None),
):
    """
    Translate the provided source text into pictograms and return a corresponding image.

    - **src**: The source text to be translated.
    - **language**: The source language. Accepted values are 'en', 'fr', 'ta'.
    - **backgroundColor**: (Optional) Background color for the pictogram image, specified by name (e.g., 'red'). Defaults to 'white'.
    - **backgroundColorRGB**: (Optional) Background color for the image in RGB format (e.g., 255,0,0 for red). Mutually exclusive with backgroundColor.
    """
    # Reject requests that set both backgroundColor and backgroundColorRGB;
    # backgroundColor defaults to None so that providing only the RGB variant works
    if backgroundColor and backgroundColorRGB:
        raise HTTPException(
            status_code=400,
            detail="You cannot provide both backgroundColor and backgroundColorRGB at the same time.",
        )

    # Ensure that the given language is a supported one
    if request.language not in ["en", "fr", "ta"]:
        raise HTTPException(
            status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'."
        )

    # Resolve the background color, defaulting to white (255, 255, 255)
    if backgroundColorRGB:
        background_color = backgroundColorRGB
    elif backgroundColor:
        background_color = COLORS.get(backgroundColor.lower(), (255, 255, 255))
    else:
        background_color = (255, 255, 255)

    # The current model only handles French; return an empty placeholder
    # response for English and Tamil until those models are available
    if request.language in ["en", "ta"]:
        return TranslationResponse(
            language=request.language,
            src=request.src,
            tgt=None,
            pictogram_ids=[],
            image_base64=None,
        )

    # Translate the French input with the mBART model
    inputs = tokenizer(
        request.src, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each word of the translated sentence to a pictogram ID, dropping
    # words that have no entry in the vocabulary
    words = tgt_sentence.split()
    pictogram_ids = [pictogram_dict.get(word.lower()) for word in words]
    pictogram_ids = [picto_id for picto_id in pictogram_ids if picto_id]

    if pictogram_ids:
        # Generate the pictogram image and return it as a base64-encoded PNG
        final_image = create_pictogram_image(pictogram_ids, background_color)
        if final_image:
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format="PNG")
            encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
            return TranslationResponse(
                language=request.language,
                src=request.src,
                tgt=tgt_sentence,
                pictogram_ids=pictogram_ids,
                image_base64=encoded_image,
            )

    return TranslationResponse(
        language=request.language,
        src=request.src,
        tgt=tgt_sentence,
        pictogram_ids=pictogram_ids,
        image_base64=None,
    )
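# Minimal local entry point, assuming uvicorn is installed; the host and port
# below are illustrative choices, not part of the original deployment setup.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)

# Example call (illustrative):
#   curl -X POST "http://127.0.0.1:8000/v1/translate?backgroundColor=blue" \
#        -H "Content-Type: application/json" \
#        -d '{"src": "le chat mange la souris", "language": "fr"}'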