Spaces:
Sleeping
Sleeping
File size: 4,395 Bytes
60ffe25 aa9f8f2 60ffe25 6cbe35a 60ffe25 aa9f8f2 60ffe25 aa9f8f2 60ffe25 aa9f8f2 6cbe35a aa9f8f2 60ffe25 aa9f8f2 60ffe25 aa9f8f2 60ffe25 aa9f8f2 60ffe25 aa9f8f2 6cbe35a aa9f8f2 60ffe25 aa9f8f2 60ffe25 aa9f8f2 60ffe25 aa9f8f2 60ffe25 aa9f8f2 60ffe25 aa9f8f2 6cbe35a aa9f8f2 6cbe35a aa9f8f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import MBartForConditionalGeneration, MBartTokenizerFast
import torch
import requests, json, base64
from io import BytesIO
from typing import List, Optional
# FastAPI application object; the keyword arguments are the API's title,
# semantic version and human-readable description.
# NOTE(review): the description advertises English and Tamil, but /v1/translate
# currently returns empty results for 'en' and 'ta' (only 'fr' is translated).
app = FastAPI(
    description="An API for converting text to pictograms, supporting English, French, and Tamil.",
    version="1.0.0",
    title="Text-to-Pictogram API",
)
# Define schemas for requests and responses
class TranslationRequest(BaseModel):
    # Request body of POST /v1/translate.
    src: str = Field(..., description="Source text to be translated.")
    # Plain str on purpose: the endpoint itself rejects values other than
    # 'en'/'fr'/'ta' with HTTP 400 rather than relying on schema validation.
    language: str = Field(..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.")
class TranslationResponse(BaseModel):
    # Response body of POST /v1/translate.
    language: str  # echoed back from the request
    src: str  # echoed back from the request
    tgt: Optional[str] = None  # model output; None when the language has no model yet
    # Pictogram ids matched for the output tokens, in order (empty when none matched).
    # Declared Optional so the annotation agrees with the None default.
    pictogram_ids: Optional[List[Optional[int]]] = None
    image_base64: Optional[str] = None  # base64-encoded PNG strip, or None when no image was built
# Load the fine-tuned MBart text-to-pictogram checkpoint and its fast tokenizer.
# The model is moved to the GPU when CUDA is available, otherwise it stays on CPU.
model_path = "feedlight42/mbart25-text2picto"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
model = MBartForConditionalGeneration.from_pretrained(model_path).to(device)
# Load the word -> pictogram-id vocabulary used to map model output tokens.
# JSON text is UTF-8; state it explicitly instead of relying on the platform's
# locale encoding, which can mis-decode non-ASCII (e.g. accented French) words.
with open('pictogram_vocab.json', 'r', encoding='utf-8') as f:
    pictogram_dict = json.load(f)
# Fetch a pictogram image from the ARASAAC API
def fetch_pictogram(picto_id: int, timeout: float = 10.0):
    """Download one pictogram image from the ARASAAC API.

    Args:
        picto_id: ARASAAC pictogram id.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A ``PIL.Image`` on success, or ``None`` if the request fails, times
        out, returns a non-200 status, or the body is not a decodable image.
    """
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    try:
        # Without a timeout, requests can block indefinitely on a stalled server.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return Image.open(BytesIO(response.content))
    except OSError:
        # PIL.UnidentifiedImageError (body is not an image) subclasses OSError.
        return None
# Generate an image from a sequence of pictogram IDs
def create_pictogram_image(pictogram_ids):
    """Fetch each pictogram and concatenate them horizontally into one image.

    Args:
        pictogram_ids: Iterable of ARASAAC pictogram ids, in display order.

    Returns:
        A new RGB ``PIL.Image`` that is as wide as all fetched pictograms
        combined and as tall as the tallest one. The pictograms are laid out
        left-to-right and top-aligned. Ids whose image could not be fetched are
        skipped. Returns ``None`` if no image could be fetched at all.
    """
    fetched = (fetch_pictogram(picto_id) for picto_id in pictogram_ids)
    pictogram_images = [image for image in fetched if image is not None]

    # Every fetch failed (or no ids were given): there is nothing to draw, and
    # unpacking the sizes of an empty sequence would raise ValueError.
    if not pictogram_images:
        return None

    total_width = sum(img.width for img in pictogram_images)
    max_height = max(img.height for img in pictogram_images)
    # NOTE(review): the canvas defaults to black and images are pasted without an
    # alpha mask, so paste() converts RGBA/P pictograms to RGB and drops their
    # transparency; confirm how transparent regions should render.
    final_image = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for img in pictogram_images:
        final_image.paste(img, (x_offset, 0))
        x_offset += img.width
    return final_image
# Liveness endpoint. It returns a fixed payload and does not touch the model,
# the vocabulary or the network.
@app.get("/health", summary="Health Check", response_description="Health status")
def health_check():
    # Kept docstring-free so the generated OpenAPI description stays unchanged.
    return dict(status="healthy", message="API is up and running")
# Serialize a PIL image to PNG and return it as a base64 text string.
def _encode_png_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Main translation endpoint
@app.post("/v1/translate", summary="Translate Text to Pictograms", response_model=TranslationResponse)
def translate(request: TranslationRequest):
    """Translate source text into a sequence of ARASAAC pictograms.

    Only French (`fr`) is currently translated. English (`en`) and Tamil (`ta`)
    are accepted but return an empty result: no target text, no pictograms and
    no image.

    French text is run through the text-to-pictogram model. Each
    whitespace-separated token of the output is looked up, lower-cased, in the
    pictogram vocabulary, and tokens with no entry are dropped. The matching
    pictograms are then joined left-to-right into one PNG, returned as base64.

    Returns HTTP 400 if `language` is not one of `en`, `fr`, `ta`.
    """
    if request.language not in ("en", "fr", "ta"):
        raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")

    # Placeholder for languages without a model yet: echo the input with empty results.
    if request.language in ("en", "ta"):
        return TranslationResponse(
            language=request.language,
            src=request.src,
            tgt=None,
            pictogram_ids=[],
            image_base64=None,
        )

    # Translate using the French model.
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map each output token to its pictogram id and drop tokens with no entry.
    # Compare with None explicitly so a falsy id (e.g. 0) is not discarded.
    lookups = (pictogram_dict.get(word.lower()) for word in tgt_sentence.split())
    pictogram_ids = [picto_id for picto_id in lookups if picto_id is not None]

    encoded_image = None
    if pictogram_ids:
        final_image = create_pictogram_image(pictogram_ids)
        # Checked defensively: no image means the response simply omits it.
        if final_image is not None:
            encoded_image = _encode_png_base64(final_image)

    return TranslationResponse(
        language=request.language,
        src=request.src,
        tgt=tgt_sentence,
        pictogram_ids=pictogram_ids,
        image_base64=encoded_image,
    )