File size: 3,267 Bytes
60ffe25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import base64
import json
import logging
from io import BytesIO

import requests
import torch
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel
from transformers import MBartForConditionalGeneration, MBartTokenizerFast

# Request body schema for the POST /translate endpoint (validated by pydantic).
# Documented with comments rather than a docstring/Field(...) because pydantic
# copies those into the generated JSON/OpenAPI schema.
class TranslationRequest(BaseModel):
    # Text to be translated and mapped to pictograms.
    src: str
    # Requested target language. NOTE(review): `translate` does not currently
    # read this field — confirm whether it should select the mBART target language.
    language: str
    language: str

# ASGI application; the route decorated with @app.post below is registered on it.
app = FastAPI()

# Checkpoint identifier handed to `from_pretrained` (Hub repo id or local dir).
model_path = "feedlight42/mbart25-text2picto"

# Prefer a CUDA device when torch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the seq2seq model and place it on `device` in one step
# (`nn.Module.to` returns the module itself), then its matching fast tokenizer.
model = MBartForConditionalGeneration.from_pretrained(model_path).to(device)
tokenizer = MBartTokenizerFast.from_pretrained(model_path)

# Load the word -> pictogram-id lookup table consulted by `translate`.
# The path is relative, so it resolves against the process's working directory.
# JSON text is UTF-8 by specification; state the encoding explicitly instead of
# relying on the platform's locale default, which can mis-decode non-ASCII keys.
with open('pictogram_vocab.json', 'r', encoding='utf-8') as f:
    pictogram_dict = json.load(f)

# Fetch a pictogram image from the ARASAAC API
def fetch_pictogram(picto_id, *, timeout=10.0):
    """Download a single pictogram image from the ARASAAC API.

    Args:
        picto_id: Pictogram identifier, interpolated into the request URL path.
        timeout: Seconds to wait on the HTTP request before giving up
            (keyword-only; passed straight to ``requests.get``).

    Returns:
        A decoded ``PIL.Image.Image``, or ``None`` when the request fails,
        returns an HTTP error status, or the body cannot be decoded as an
        image. Failures are logged as warnings rather than raised, so one bad
        id does not abort a whole sentence.
    """
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    logger = logging.getLogger(__name__)

    try:
        # Without a timeout, a stalled connection would block this worker forever.
        response = requests.get(url, timeout=timeout)
        # Treat HTTP error statuses as failures instead of decoding the error body.
        response.raise_for_status()
    except requests.RequestException as exc:
        logger.warning("Could not fetch pictogram %s: %s", picto_id, exc)
        return None

    try:
        image = Image.open(BytesIO(response.content))
        # Decode eagerly so an unreadable payload fails here, not later in paste().
        image.load()
    except OSError as exc:  # PIL.UnidentifiedImageError is an OSError subclass
        logger.warning("Pictogram %s is not a readable image: %s", picto_id, exc)
        return None
    return image

# Generate an image from a sequence of pictogram IDs
def create_pictogram_image(pictogram_ids):
    """Fetch each pictogram and lay them out left-to-right in one RGB image.

    Args:
        pictogram_ids: Iterable of pictogram ids, in display order.

    Returns:
        A new RGB ``PIL.Image.Image`` whose width is the sum of the fetched
        images' widths and whose height is the tallest one's height, or
        ``None`` when no pictogram image could be obtained.
    """
    # Ids whose fetch yields a falsy result are skipped, as before.
    pictogram_images = [img for img in map(fetch_pictogram, pictogram_ids) if img]
    if not pictogram_images:
        # Nothing to concatenate; unpacking empty sizes below would raise ValueError.
        return None

    total_width = sum(img.width for img in pictogram_images)
    max_height = max(img.height for img in pictogram_images)

    # Image.new's default colour is 0, so space below shorter images stays black.
    final_image = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for img in pictogram_images:
        # Top-align each pictogram directly to the right of the previous one.
        final_image.paste(img, (x_offset, 0))
        x_offset += img.width

    return final_image


@app.post("/translate")
def translate(request: TranslationRequest):
    """
    Translate text to target language and generate pictogram tokens.

    ``request.src`` is run through the mBART model. The decoded sentence is
    split on whitespace, and each lower-cased word found in ``pictogram_dict``
    contributes its pictogram id. When at least one id matches, the pictograms
    are fetched, concatenated and returned as a base64-encoded PNG.

    Returns:
        dict with keys ``src`` (input text), ``tgt`` (decoded model output),
        ``pictograms`` (matched ids, in word order) and ``image_base64``
        (PNG as a base64 string, or ``None`` when no id matched or no image
        could be built).
    """
    # NOTE(review): request.language is accepted but not used; the output
    # language presumably comes from the fine-tuned checkpoint — confirm.
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)

    # Generate translation
    translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map words to pictogram ids, dropping words with no entry. Words are
    # lower-cased first (presumably the vocab keys are lower-case — confirm).
    candidate_ids = (pictogram_dict.get(word.lower()) for word in tgt_sentence.split())
    pictogram_ids = [picto_id for picto_id in candidate_ids if picto_id is not None]

    encoded_image = None
    if pictogram_ids:
        final_image = create_pictogram_image(pictogram_ids)
        # Defensive: only encode when an image was actually produced.
        if final_image is not None:
            buffer = BytesIO()
            final_image.save(buffer, format='PNG')
            encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')

    return {
        "src": request.src,
        "tgt": tgt_sentence,
        "pictograms": pictogram_ids,
        "image_base64": encoded_image,
    }