# text2picto / main.py
# initial commit 60ffe25 (feedlight42) — 3.27 kB
from PIL import Image
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import MBartForConditionalGeneration, MBartTokenizerFast
import torch
import requests, json, base64
from io import BytesIO
# Define the input schema
class TranslationRequest(BaseModel):
    """Body of POST /translate: the source sentence and a language code."""
    # Source-language sentence to translate.
    src: str
    # Target language code.
    # NOTE(review): never read by the /translate handler below — confirm intent.
    language: str
# Initialize FastAPI app
app = FastAPI()

# Load the model and tokenizer once at import time (shared across requests).
model_path = "feedlight42/mbart25-text2picto"  # HuggingFace Hub model id
# Prefer GPU when one is available; fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MBartForConditionalGeneration.from_pretrained(model_path)
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = model.to(device)

# Load the pictogram dictionary from the JSON file.
# Presumably maps lowercase word -> ARASAAC pictogram id (see usage below) —
# TODO confirm against the file contents.
with open('pictogram_vocab.json', 'r') as f:
    pictogram_dict = json.load(f)
# Fetch a pictogram image from the ARASAAC API
def fetch_pictogram(picto_id):
    """Download the pictogram image for an ARASAAC pictogram id.

    Returns a PIL.Image on success, or None when the request fails or the
    payload is not a decodable image — callers treat a falsy result as
    "skip this pictogram".
    """
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    try:
        # Timeout so a slow/unreachable API cannot hang the request handler.
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # surface HTTP 4xx/5xx instead of parsing an error body
        return Image.open(BytesIO(response.content))
    except (requests.RequestException, OSError):
        # Network failure or non-image payload: report "no pictogram"
        # rather than crashing the whole translation request.
        return None
# Generate an image from a sequence of pictogram IDs
def create_pictogram_image(pictogram_ids):
    """Fetch each pictogram and concatenate them left-to-right into one image.

    Pictograms that fail to fetch are skipped. Returns a new RGB image whose
    width is the sum of the fetched widths and whose height is the tallest
    fetched pictogram.

    Raises ValueError when none of the ids yields an image (the original code
    raised the same exception type from an empty-zip unpack; this makes the
    message explicit).
    """
    fetched = (fetch_pictogram(picto_id) for picto_id in pictogram_ids)
    pictogram_images = [img for img in fetched if img]
    if not pictogram_images:
        raise ValueError("no pictogram images could be fetched for the given ids")

    # Concatenate all pictogram images horizontally.
    widths, heights = zip(*(img.size for img in pictogram_images))
    final_image = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for img in pictogram_images:
        final_image.paste(img, (x_offset, 0))
        x_offset += img.size[0]
    return final_image
@app.post("/translate")
def translate(request: TranslationRequest):
    """
    Translate text with the mBART model and render matching pictograms.

    Returns a dict with the original text (`src`), the model translation
    (`tgt`), the matched ARASAAC pictogram ids (`pictograms`), and a
    base64-encoded PNG strip of those pictograms (`image_base64`, or None
    when no word mapped to a pictogram).

    NOTE(review): `request.language` is never read here — the target language
    appears to be fixed by the model checkpoint; confirm intent.
    """
    # Tokenize on the configured device and generate the translation.
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each translated word to a pictogram id; words without an entry are
    # dropped (single lookup per word instead of .get(...) plus a None-filter).
    pictogram_ids = [
        pictogram_dict[word.lower()]
        for word in tgt_sentence.split()
        if word.lower() in pictogram_dict
    ]

    # Render and base64-encode the pictogram strip only when something matched.
    encoded_image = None
    if pictogram_ids:
        final_image = create_pictogram_image(pictogram_ids)
        buffer = BytesIO()
        final_image.save(buffer, format='PNG')
        encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')

    # Single response shape for both branches (the original duplicated this dict).
    return {
        "src": request.src,
        "tgt": tgt_sentence,
        "pictograms": pictogram_ids,
        "image_base64": encoded_image
    }