import base64
import json
from io import BytesIO

import requests
import torch
from PIL import Image
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import MBartForConditionalGeneration, MBartTokenizerFast

# Define the input schema
class TranslationRequest(BaseModel):
    src: str
    language: str  # target language code (currently unused by the endpoint)

# Initialize FastAPI app
app = FastAPI()
# Load the model and tokenizer
model_path = "feedlight42/mbart25-text2picto"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MBartForConditionalGeneration.from_pretrained(model_path)
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = model.to(device)

# Load the pictogram dictionary from the JSON file
with open('pictogram_vocab.json', 'r') as f:
    pictogram_dict = json.load(f)
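
# The vocabulary format is assumed from how the dictionary is used below
# (lowercased word -> ARASAAC pictogram ID); the entries here are made up
# purely for illustration:
#   {"pomme": 2462, "manger": 6456}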

# Fetch a pictogram image from the ARASAAC API; return None if the request fails
def fetch_pictogram(picto_id):
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    response = requests.get(url)
    if not response.ok:
        return None
    return Image.open(BytesIO(response.content))

# Generate an image from a sequence of pictogram IDs
def create_pictogram_image(pictogram_ids):
    pictogram_images = []
    for picto_id in pictogram_ids:
        picto_image = fetch_pictogram(picto_id)
        if picto_image:
            pictogram_images.append(picto_image)
    # Bail out if every fetch failed, so the caller can handle the missing image
    if not pictogram_images:
        return None
    # Concatenate all pictogram images horizontally
    widths, heights = zip(*(i.size for i in pictogram_images))
    total_width = sum(widths)
    max_height = max(heights)
    final_image = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for img in pictogram_images:
        final_image.paste(img, (x_offset, 0))
        x_offset += img.size[0]
    return final_image

@app.post("/translate")
def translate(request: TranslationRequest):
    """
    Translate text to the target language and map it to pictograms.
    """
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    # Generate the translation
    translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each word of the translated sentence to a pictogram ID, dropping unknown words
    words = tgt_sentence.split()
    pictogram_ids = [pictogram_dict[word.lower()] for word in words if word.lower() in pictogram_dict]

    # Render the pictogram strip and encode it as a base64 PNG, if any IDs were found
    encoded_image = None
    if pictogram_ids:
        final_image = create_pictogram_image(pictogram_ids)
        if final_image is not None:
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format='PNG')
            img_byte_arr.seek(0)
            encoded_image = base64.b64encode(img_byte_arr.read()).decode('utf-8')

    return {
        "src": request.src,
        "tgt": tgt_sentence,
        "pictograms": pictogram_ids,
        "image_base64": encoded_image,  # None when no pictograms could be rendered
    }
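
# Optional entry point for local runs (a convenience sketch; a Space would
# normally launch the app with its own uvicorn command instead):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request (illustrative; the filename `main.py`, the port, and the
# sample sentence are assumptions, not taken from this repository):
#
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"src": "je mange une pomme", "language": "fr"}'
#
# The response carries the source text, the translated sentence, the resolved
# pictogram IDs, and a base64-encoded PNG (or null if nothing could be rendered).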