"""FastAPI service that turns text into a strip of ARASAAC pictograms.

``POST /translate`` runs the mBART model ``feedlight42/mbart25-text2picto`` on
the input text and splits the generated sentence into words. Each word is
looked up (lower-cased) in ``pictogram_vocab.json`` to get a pictogram id.
The matched pictograms are then downloaded from the ARASAAC API and returned
as a base64-encoded PNG strip.
"""

import base64
import json
import logging
from io import BytesIO

import requests
import torch
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel
from transformers import MBartForConditionalGeneration, MBartTokenizerFast

logger = logging.getLogger(__name__)

# URL template for downloading a single pictogram image by id.
ARASAAC_PICTOGRAM_URL = "https://api.arasaac.org/v1/pictograms/{picto_id}"
# Seconds to wait on ARASAAC per pictogram. Without a timeout, a stalled
# connection would block the request worker indefinitely.
ARASAAC_TIMEOUT = 10


# Define the input schema
class TranslationRequest(BaseModel):
    """Request body for ``POST /translate``."""

    src: str  # text to translate into pictogram tokens
    # NOTE(review): accepted for API compatibility but not read by the
    # endpoint below; confirm whether it should select a language.
    language: str


# Initialize FastAPI app
app = FastAPI()

# Load the model and tokenizer once at import time; every request reuses them.
model_path = "feedlight42/mbart25-text2picto"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MBartForConditionalGeneration.from_pretrained(model_path)
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = model.to(device)

# Load the word -> pictogram-id mapping. JSON text is UTF-8 (RFC 8259), so the
# encoding is passed explicitly rather than relying on the platform default.
with open('pictogram_vocab.json', 'r', encoding='utf-8') as f:
    pictogram_dict = json.load(f)


def fetch_pictogram(picto_id, timeout=ARASAAC_TIMEOUT):
    """Download one pictogram image from the ARASAAC API.

    Args:
        picto_id: Pictogram identifier, interpolated into the request URL.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` if any of these happens:
        the request fails, it returns a non-2xx status, or the body is not a
        readable image.
    """
    url = ARASAAC_PICTOGRAM_URL.format(picto_id=picto_id)
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy. Force decoding now so a corrupt payload is
        # caught here instead of failing later inside paste().
        image.load()
    except (requests.RequestException, OSError) as exc:
        logger.warning("Could not fetch pictogram %s: %s", picto_id, exc)
        return None
    return image


def create_pictogram_image(pictogram_ids):
    """Fetch pictograms and lay them out left-to-right in one RGB image.

    Pictograms that cannot be fetched are skipped. The remaining images are
    top-aligned on a white canvas. The canvas is as wide as all images
    combined and as tall as the tallest one.

    Args:
        pictogram_ids: Iterable of pictogram ids, in display order.

    Returns:
        The composed ``PIL.Image.Image``, or ``None`` if no pictogram could be
        fetched.
    """
    fetched = (fetch_pictogram(picto_id) for picto_id in pictogram_ids)
    pictogram_images = [img for img in fetched if img is not None]
    if not pictogram_images:
        return None

    total_width = sum(img.width for img in pictogram_images)
    max_height = max(img.height for img in pictogram_images)
    # White background: Image.new defaults to black, which would show under
    # shorter pictograms.
    final_image = Image.new('RGB', (total_width, max_height), 'white')

    x_offset = 0
    for img in pictogram_images:
        rgba = img.convert('RGBA')
        # Use the alpha channel as the paste mask so any transparent regions
        # keep the white background (fully opaque images paste unchanged).
        final_image.paste(rgba, (x_offset, 0), rgba)
        x_offset += img.width
    return final_image


@app.post("/translate")
def translate(request: TranslationRequest):
    """
    Translate text into pictogram tokens and render them as an image.

    Returns a JSON object with these keys:
    ``src`` — the input text;
    ``tgt`` — the generated sentence;
    ``pictograms`` — the pictogram ids matched from its words;
    ``image_base64`` — a base64-encoded PNG strip of those pictograms, or
    ``None`` when no word matched or no pictogram could be fetched.
    """
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)

    # Generate translation
    translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each whitespace-separated word (case-insensitively) to a pictogram
    # id. Words with no entry, or a null entry, are dropped.
    candidates = (pictogram_dict.get(word.lower()) for word in tgt_sentence.split())
    pictogram_ids = [picto_id for picto_id in candidates if picto_id is not None]

    encoded_image = None
    if pictogram_ids:
        final_image = create_pictogram_image(pictogram_ids)
        if final_image is not None:
            buffer = BytesIO()
            final_image.save(buffer, format='PNG')
            encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')

    return {
        "src": request.src,
        "tgt": tgt_sentence,
        "pictograms": pictogram_ids,
        "image_base64": encoded_image,
    }