Spaces:
Sleeping
Sleeping
import base64
import json
import logging
from io import BytesIO

import requests
import torch
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel
from transformers import MBartForConditionalGeneration, MBartTokenizerFast
# Define the input schema
class TranslationRequest(BaseModel):
    """Request body accepted by ``translate``.

    Being a pydantic ``BaseModel``, both fields are required strings and are
    validated when an instance is constructed.
    """

    # Source text that is tokenized and fed to the mBART model.
    src: str
    # NOTE(review): not read anywhere in the visible code — confirm what
    # values are expected and whether it should influence generation.
    language: str
# Initialize FastAPI app
app = FastAPI()

# Load the model and tokenizer once, at import time, so every request reuses them.
model_path = "feedlight42/mbart25-text2picto"
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MBartForConditionalGeneration.from_pretrained(model_path)
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = model.to(device)

# Load the pictogram lookup table (lower-cased word -> pictogram ID, as it is
# queried by ``translate``). The path is resolved against the current working
# directory. The file is decoded as UTF-8 explicitly: JSON is UTF-8 by spec,
# and the platform default encoding could mis-decode non-ASCII keys.
with open('pictogram_vocab.json', 'r', encoding='utf-8') as f:
    pictogram_dict = json.load(f)
# Fetch a pictogram image from the ARASAAC API
def fetch_pictogram(picto_id, timeout=10):
    """Download a single pictogram image from the ARASAAC API.

    Args:
        picto_id: ARASAAC pictogram identifier, interpolated into the URL path.
        timeout: Seconds to wait for the HTTP request. Without a timeout,
            ``requests`` can block indefinitely on a stalled connection and
            hang the calling request handler.

    Returns:
        A fully decoded ``PIL.Image.Image``, or ``None`` if the request
        failed, returned an HTTP error status, or did not contain a readable
        image. Returning ``None`` lets callers skip a missing pictogram
        instead of failing the whole response.
    """
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    log = logging.getLogger(__name__)

    try:
        response = requests.get(url, timeout=timeout)
        # Without this, an error page would be handed to Image.open below.
        response.raise_for_status()
    except requests.RequestException as exc:
        log.warning("Could not fetch pictogram %s: %s", picto_id, exc)
        return None

    try:
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data is caught here
        # rather than later when the image is pasted.
        image.load()
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        log.warning("Pictogram %s is not a readable image: %s", picto_id, exc)
        return None

    return image
# Generate an image from a sequence of pictogram IDs
def create_pictogram_image(pictogram_ids):
    """Lay out pictograms side by side in a single RGB image.

    Each ID is fetched with ``fetch_pictogram``; IDs for which it yields
    ``None`` are skipped. Images are pasted left to right, top-aligned, on a
    canvas whose width is the sum of the image widths and whose height is
    that of the tallest image.

    Args:
        pictogram_ids: Iterable of pictogram IDs, in display order.

    Returns:
        The composed ``PIL.Image.Image`` (mode ``'RGB'``), or ``None`` when
        no pictogram image could be obtained. Previously, this case crashed
        with an opaque unpacking ``ValueError`` from ``zip``.
    """
    pictogram_images = [
        image
        for image in map(fetch_pictogram, pictogram_ids)
        if image is not None
    ]
    if not pictogram_images:
        return None

    total_width = sum(image.width for image in pictogram_images)
    max_height = max(image.height for image in pictogram_images)

    # Image.new's default color is 0, so any area below a shorter
    # pictogram stays black.
    final_image = Image.new('RGB', (total_width, max_height))

    x_offset = 0
    for image in pictogram_images:
        # paste() converts the source to the canvas mode (RGB) when modes
        # differ. NOTE(review): no alpha mask is passed, so any transparency
        # in the source images is discarded — confirm this is intended.
        final_image.paste(image, (x_offset, 0))
        x_offset += image.width
    return final_image
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# function, so as written FastAPI never exposes it as an endpoint — confirm
# whether a decorator was dropped.
def translate(request: TranslationRequest):
    """Translate ``request.src`` into pictogram tokens and render them.

    The source text is run through the mBART model. The decoded output is
    split on whitespace, and each lower-cased word is looked up in
    ``pictogram_dict``; words with no entry are dropped.

    Args:
        request: Parsed request body. Only ``src`` is read; ``language`` is
            currently ignored (NOTE(review): confirm whether it should drive
            tokenizer or generation language settings).

    Returns:
        A dict with these keys:

        - ``src``: the input text.
        - ``tgt``: the decoded model output.
        - ``pictograms``: the matched pictogram IDs, in output order.
        - ``image_base64``: a base64-encoded PNG of the pictograms laid out
          side by side, or ``None`` when no word matched or no image could
          be rendered.
    """
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    # Generate the translation and decode the first (only) sequence.
    translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each output word to a pictogram ID, keeping only words with an entry.
    pictogram_ids = [
        picto_id
        for picto_id in (pictogram_dict.get(word.lower()) for word in tgt_sentence.split())
        if picto_id is not None
    ]

    encoded_image = None
    if pictogram_ids:
        final_image = create_pictogram_image(pictogram_ids)
        # The image may be None (e.g. no pictogram could be fetched); in that
        # case answer without an image instead of crashing on .save().
        if final_image is not None:
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format='PNG')
            encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')

    return {
        "src": request.src,
        "tgt": tgt_sentence,
        "pictograms": pictogram_ids,
        "image_base64": encoded_image,
    }