from io import BytesIO
from typing import List, Optional
import base64
import json

import requests
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel, Field
from transformers import MBartForConditionalGeneration, MBartTokenizerFast

# Initialize FastAPI app with versioning
app = FastAPI(
    title="Text-to-Pictogram API",
    version="1.0.0",
    description="An API for converting text to pictograms, supporting English, French, and Tamil.",
)


# Define schemas for requests and responses
class TranslationRequest(BaseModel):
    src: str = Field(..., description="Source text to be translated.")
    language: str = Field(..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.")


class TranslationResponse(BaseModel):
    language: str
    src: str
    tgt: Optional[str] = None
    image_base64: Optional[str] = None


# Load the model and tokenizer
model_path = "feedlight42/mbart25-text2picto"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MBartForConditionalGeneration.from_pretrained(model_path)
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = model.to(device)

# Load the pictogram dictionary from the JSON file
with open("pictogram_vocab.json", "r") as f:
    pictogram_dict = json.load(f)


# Fetch a pictogram image from the ARASAAC API
def fetch_pictogram(picto_id: int) -> Optional[Image.Image]:
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    response = requests.get(url, timeout=10)
    if response.status_code == 200:
        return Image.open(BytesIO(response.content))
    return None


# Generate a single image from a sequence of pictogram IDs
def create_pictogram_image(pictogram_ids: List[int]) -> Optional[Image.Image]:
    pictogram_images = []
    for picto_id in pictogram_ids:
        picto_image = fetch_pictogram(picto_id)
        if picto_image:
            pictogram_images.append(picto_image)

    # Guard against the case where no pictogram could be fetched;
    # the zip(*...) unpacking below would raise on an empty list.
    if not pictogram_images:
        return None

    # Concatenate all pictogram images side by side
    widths, heights = zip(*(img.size for img in pictogram_images))
    total_width = sum(widths)
    max_height = max(heights)

    final_image = Image.new("RGB", (total_width, max_height))
    x_offset = 0
    for img in pictogram_images:
        final_image.paste(img, (x_offset, 0))
        x_offset += img.size[0]
    return final_image


# Endpoint for health check
@app.get("/health", summary="Health Check", response_description="Health status")
def health_check():
    return {"status": "healthy", "message": "API is up and running"}


# Main translation endpoint
@app.post("/v1/translate", summary="Translate Text to Pictograms", response_model=TranslationResponse)
def translate(request: TranslationRequest):
    if request.language not in ["en", "fr", "ta"]:
        raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")

    # Placeholder response for languages without a trained model yet
    if request.language in ["en", "ta"]:
        return TranslationResponse(
            language=request.language,
            src=request.src,
            tgt=None,
            image_base64=None,
        )

    # Translate French text into a pictogram token sequence
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each word of the translated sentence to a pictogram ID, dropping unknown words
    words = tgt_sentence.split()
    pictogram_ids = [pictogram_dict.get(word.lower()) for word in words]
    pictogram_ids = [picto_id for picto_id in pictogram_ids if picto_id]

    if pictogram_ids:
        # Render the concatenated pictogram image and return it base64-encoded
        final_image = create_pictogram_image(pictogram_ids)
        if final_image:
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format="PNG")
            encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
            return TranslationResponse(
                language=request.language,
                src=request.src,
                tgt=tgt_sentence,
                image_base64=encoded_image,
            )

    # Fallback: return the translated sentence without an image
    return TranslationResponse(
        language=request.language,
        src=request.src,
        tgt=tgt_sentence,
        image_base64=None,
    )
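
# Local development entry point: a minimal sketch, assuming uvicorn is
# installed and this file is saved as main.py (both are assumptions, not
# part of the original service definition).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against a locally running server (hypothetical host/port):
#
#   curl -X POST http://localhost:8000/v1/translate \
#        -H "Content-Type: application/json" \
#        -d '{"src": "le chat mange une pomme", "language": "fr"}'
#
# A French request returns the pictogram-token sentence in "tgt" and a
# base64-encoded PNG strip in "image_base64"; "en" and "ta" currently
# return placeholder responses with both fields null.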