text2picto / main.py
feedlight42's picture
better versioning, best practices
aa9f8f2
raw
history blame
4.23 kB
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import MBartForConditionalGeneration, MBartTokenizerFast
import torch
import requests, json, base64
from io import BytesIO
from typing import Optional
# Application object: metadata here is surfaced verbatim in the generated
# OpenAPI schema and the /docs UI.
app = FastAPI(
    title="Text-to-Pictogram API",
    description="An API for converting text to pictograms, supporting English, French, and Tamil.",
    version="1.0.0",
)
# Request/response schemas. NOTE: these are declarative pydantic models —
# field order and Field(...) metadata feed straight into the OpenAPI schema.
class TranslationRequest(BaseModel):
    # Body of POST /v1/translate. Both fields are required (Field(...)).
    src: str = Field(..., description="Source text to be translated.")
    language: str = Field(..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.")
class TranslationResponse(BaseModel):
    # Response of POST /v1/translate. `tgt` and `image_base64` stay None
    # when the language has no model yet ('en'/'ta') or, for the image,
    # when no pictogram could be rendered for the translation.
    language: str
    src: str
    tgt: Optional[str] = None
    image_base64: Optional[str] = None  # PNG strip, base64-encoded
# --- One-time startup work (runs at import time) ---
# Pick the compute device first; all inference tensors are moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seq2seq model that rewrites French text into pictogram-friendly wording.
model_path = "feedlight42/mbart25-text2picto"
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = MBartForConditionalGeneration.from_pretrained(model_path).to(device)

# word -> ARASAAC pictogram-id lookup used after translation.
with open('pictogram_vocab.json', 'r') as f:
    pictogram_dict = json.load(f)
def fetch_pictogram(picto_id: int):
    """Fetch one pictogram image from the ARASAAC API.

    Args:
        picto_id: ARASAAC pictogram identifier.

    Returns:
        A ``PIL.Image`` on success, or ``None`` when the request fails
        (non-200 status, timeout, or any network error). Callers already
        treat ``None`` as "skip this pictogram".
    """
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    try:
        # Timeout: without it a stalled ARASAAC response hangs the worker
        # indefinitely (requests has no default timeout).
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        # Network failures used to propagate as an unhandled 500; degrade
        # to "pictogram unavailable" instead.
        return None
    if response.status_code == 200:
        return Image.open(BytesIO(response.content))
    return None
def create_pictogram_image(pictogram_ids):
    """Concatenate the images for *pictogram_ids* into one horizontal strip.

    IDs whose image cannot be fetched are skipped silently.

    Args:
        pictogram_ids: iterable of ARASAAC pictogram identifiers.

    Returns:
        A ``PIL.Image`` containing all fetched pictograms side by side,
        or ``None`` when no image could be fetched at all.
    """
    pictogram_images = []
    for picto_id in pictogram_ids:
        picto_image = fetch_pictogram(picto_id)
        if picto_image:
            pictogram_images.append(picto_image)
    # Bug fix: with zero fetched images the empty unpack below raised
    # ValueError; return the None sentinel the caller already checks for.
    if not pictogram_images:
        return None
    # Strip is as wide as the sum of widths and as tall as the tallest image.
    widths, heights = zip(*(i.size for i in pictogram_images))
    final_image = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for img in pictogram_images:
        final_image.paste(img, (x_offset, 0))
        x_offset += img.size[0]
    return final_image
@app.get("/health", summary="Health Check", response_description="Health status")
def health_check():
    """Liveness probe: confirms the service is up and accepting requests."""
    payload = {"status": "healthy", "message": "API is up and running"}
    return payload
@app.post("/v1/translate", summary="Translate Text to Pictograms", response_model=TranslationResponse)
def translate(request: TranslationRequest):
    """Translate source text into a pictogram strip.

    Only French ('fr') is backed by the model today; 'en' and 'ta' return a
    placeholder response with ``tgt`` and ``image_base64`` set to None.

    Raises:
        HTTPException: 400 when ``language`` is not one of 'en', 'fr', 'ta'.
    """
    if request.language not in ["en", "fr", "ta"]:
        raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")
    # Placeholder for unsupported languages
    if request.language in ["en", "ta"]:
        return TranslationResponse(
            language=request.language,
            src=request.src,
            tgt=None,
            image_base64=None,
        )

    # Translate using the French model. inference_mode disables autograd
    # bookkeeping — generate() never needs gradients, and this lowers the
    # per-request memory footprint.
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.inference_mode():
        translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each translated word to its ARASAAC id. `is not None` (instead of
    # plain truthiness) keeps an id of 0 from being silently dropped.
    words = tgt_sentence.split()
    pictogram_ids = [
        pid for pid in (pictogram_dict.get(word.lower()) for word in words)
        if pid is not None
    ]

    encoded_image = None
    if pictogram_ids:
        final_image = create_pictogram_image(pictogram_ids)
        if final_image:
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format="PNG")
            encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")

    # Single exit point: image is None unless rendering succeeded above.
    return TranslationResponse(
        language=request.language,
        src=request.src,
        tgt=tgt_sentence,
        image_base64=encoded_image,
    )