# Text-to-Pictogram API — FastAPI service (deployed as a Hugging Face Space).
# Standard library
import base64
import json
from io import BytesIO
from typing import Optional

# Third-party
import requests
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel, Field
from transformers import MBartForConditionalGeneration, MBartTokenizerFast
# FastAPI application instance; the metadata below feeds the auto-generated docs.
app = FastAPI(
    title="Text-to-Pictogram API",
    description="An API for converting text to pictograms, supporting English, French, and Tamil.",
    version="1.0.0",
)
# Define schemas for requests and responses
class TranslationRequest(BaseModel):
    """Request body for the translation endpoint: source text plus its language code."""

    src: str = Field(..., description="Source text to be translated.")
    language: str = Field(..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.")
class TranslationResponse(BaseModel):
    """Response body: echoes the request and carries the optional translation and image."""

    # Language code echoed back from the request.
    language: str
    # Original source text echoed back from the request.
    src: str
    # Model-generated target sentence; None for languages without a model.
    tgt: Optional[str] = None
    # Base64-encoded PNG of the concatenated pictograms; None when no pictogram matched.
    image_base64: Optional[str] = None
# --- Model initialisation -------------------------------------------------
# Hugging Face checkpoint that rewrites text into a pictogram-friendly sentence.
model_path = "feedlight42/mbart25-text2picto"

# Prefer GPU when available; otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = MBartForConditionalGeneration.from_pretrained(model_path).to(device)

# Word -> ARASAAC pictogram-id mapping used to render the translated sentence.
with open('pictogram_vocab.json', 'r') as f:
    pictogram_dict = json.load(f)
# Fetch a pictogram image from the ARASAAC API
def fetch_pictogram(picto_id: int):
    """Fetch the pictogram image for *picto_id* from the ARASAAC API.

    Returns the decoded PIL image, or ``None`` when the pictogram is
    unavailable (non-200 status, network error, or timeout).
    """
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    try:
        # Timeout prevents a hung upstream request from blocking the worker forever.
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        # Transport-level failure: treat as "pictogram unavailable" rather
        # than letting the exception bubble up to the endpoint.
        return None
    if response.status_code == 200:
        return Image.open(BytesIO(response.content))
    return None
# Generate an image from a sequence of pictogram IDs
def create_pictogram_image(pictogram_ids):
    """Horizontally concatenate the pictograms for *pictogram_ids*.

    Pictograms that cannot be fetched are skipped.  Returns the combined
    PIL image, or ``None`` when no pictogram could be fetched at all —
    the caller checks for a falsy result before encoding the image.
    """
    pictogram_images = []
    for picto_id in pictogram_ids:
        picto_image = fetch_pictogram(picto_id)
        if picto_image:
            pictogram_images.append(picto_image)
    # Guard the empty case: without it, zip(*()) below raises
    # "ValueError: not enough values to unpack" instead of returning None.
    if not pictogram_images:
        return None
    # Canvas is as wide as all images side by side and as tall as the tallest.
    widths, heights = zip(*(i.size for i in pictogram_images))
    final_image = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for img in pictogram_images:
        final_image.paste(img, (x_offset, 0))
        x_offset += img.size[0]
    return final_image
# Endpoint for health check
@app.get("/health")  # the comment said "endpoint" but the route decorator was missing
def health_check():
    """Liveness probe: confirms the API process is up and serving requests."""
    return {"status": "healthy", "message": "API is up and running"}
# Main translation endpoint
@app.post("/translate", response_model=TranslationResponse)  # route decorator was missing
def translate(request: TranslationRequest):
    """Translate ``request.src`` into a pictogram sentence and composite image.

    Only French ('fr') is currently backed by the model; 'en' and 'ta' are
    accepted but return an empty placeholder response.

    Raises:
        HTTPException: 400 when ``request.language`` is not one of
            'en', 'fr', 'ta'.
    """
    if request.language not in ["en", "fr", "ta"]:
        raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")
    # Placeholder for unsupported languages
    if request.language in ["en", "ta"]:
        return TranslationResponse(
            language=request.language,
            src=request.src,
            tgt=None,
            image_base64=None,
        )
    # Translate using the French model.  Inference only: no_grad skips the
    # autograd graph (less memory, faster generation).
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    # Map each translated word to an ARASAAC pictogram id; words missing
    # from the vocabulary are dropped.
    words = tgt_sentence.split()
    pictogram_ids = [pictogram_dict.get(word.lower(), None) for word in words]
    pictogram_ids = [picto_id for picto_id in pictogram_ids if picto_id]
    if pictogram_ids:
        # Generate pictogram image and ship it base64-encoded in the JSON body.
        final_image = create_pictogram_image(pictogram_ids)
        if final_image:
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format="PNG")
            encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
            return TranslationResponse(
                language=request.language,
                src=request.src,
                tgt=tgt_sentence,
                image_base64=encoded_image,
            )
    # No pictogram matched (or image generation failed): return text only.
    return TranslationResponse(
        language=request.language,
        src=request.src,
        tgt=tgt_sentence,
        image_base64=None,
    )