import base64
import json
from io import BytesIO
from typing import List, Optional

import requests
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel, Field
from transformers import MBartForConditionalGeneration, MBartTokenizerFast


# Initialize FastAPI app with versioning
app = FastAPI(
    title="Text-to-Pictogram API",
    version="1.0.0",
    description="An API for converting text to pictograms, supporting English, French, and Tamil.",
)


# Define schemas for requests and responses
class TranslationRequest(BaseModel):
    src: str = Field(..., description="Source text to be translated.")
    language: str = Field(..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.")

class TranslationResponse(BaseModel):
    language: str
    src: str
    tgt: Optional[str] = None
    pictogram_ids: List[int] = []
    image_base64: Optional[str] = None
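
# Example response body (illustrative values only; the base64 string is truncated):
# {
#   "language": "fr",
#   "src": "le chat mange",
#   "tgt": "chat manger",
#   "pictogram_ids": [123, 456],
#   "image_base64": "iVBORw0KGgo..."
# }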


# Load the model and tokenizer
model_path = "feedlight42/mbart25-text2picto"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MBartForConditionalGeneration.from_pretrained(model_path)
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = model.to(device)

# Load the pictogram dictionary from the JSON file
with open('pictogram_vocab.json', 'r') as f:
    pictogram_dict = json.load(f)
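# The vocab file is assumed (based on how the dict is used below) to be a flat
# mapping from lowercase word to ARASAAC pictogram ID, e.g. {"chat": 123}.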

# Fetch a pictogram image from the ARASAAC API; returns None on any failure
def fetch_pictogram(picto_id: int) -> Optional[Image.Image]:
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    try:
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        return None
    if response.status_code == 200:
        return Image.open(BytesIO(response.content))
    return None

# Generate an image from a sequence of pictogram IDs
def create_pictogram_image(pictogram_ids):
    pictogram_images = []
    for picto_id in pictogram_ids:
        picto_image = fetch_pictogram(picto_id)
        if picto_image:
            pictogram_images.append(picto_image)

    # Nothing could be fetched; let the caller fall back to a text-only response
    if not pictogram_images:
        return None

    # Concatenate all pictogram images horizontally
    widths, heights = zip(*(img.size for img in pictogram_images))
    total_width = sum(widths)
    max_height = max(heights)

    final_image = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for img in pictogram_images:
        final_image.paste(img, (x_offset, 0))
        x_offset += img.size[0]

    return final_image


# Endpoint for health check
@app.get("/health", summary="Health Check", response_description="Health status")
def health_check():
    return {"status": "healthy", "message": "API is up and running"}

# Main translation endpoint
@app.post("/v1/translate", summary="Translate Text to Pictograms", response_model=TranslationResponse)
def translate(request: TranslationRequest):
    if request.language not in ["en", "fr", "ta"]:
        raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")

    # English and Tamil are accepted but not yet backed by a model; return an empty result
    if request.language in ["en", "ta"]:
        return TranslationResponse(
            language=request.language,
            src=request.src,
            tgt=None,
            pictogram_ids=[],
            image_base64=None,
        )

    # Translate using French model
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each word of the translated sentence to a pictogram ID, dropping unmapped words
    words = tgt_sentence.split()
    pictogram_ids = [pictogram_dict.get(word.lower()) for word in words]
    pictogram_ids = [picto_id for picto_id in pictogram_ids if picto_id is not None]

    if pictogram_ids:
        # Generate pictogram image
        final_image = create_pictogram_image(pictogram_ids)
        if final_image:
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format="PNG")
            encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
            return TranslationResponse(
                language=request.language,
                src=request.src,
                tgt=tgt_sentence,
                pictogram_ids=pictogram_ids,
                image_base64=encoded_image,
            )

    return TranslationResponse(
        language=request.language,
        src=request.src,
        tgt=tgt_sentence,
        pictogram_ids=pictogram_ids,
        image_base64=None,
    )
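

# Example usage (a minimal sketch; assumes this file is saved as `main.py`
# and that uvicorn is installed):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/v1/translate \
#        -H "Content-Type: application/json" \
#        -d '{"src": "le chat boit du lait", "language": "fr"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)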