Spaces:
Running
Running
Commit
·
1f74d92
1
Parent(s):
6cbe35a
implement logging for response time
Browse files
main.py
CHANGED
@@ -1,12 +1,24 @@
|
|
|
|
1 |
from PIL import Image
|
2 |
-
from fastapi import FastAPI, HTTPException
|
|
|
|
|
3 |
from pydantic import BaseModel, Field
|
4 |
from transformers import MBartForConditionalGeneration, MBartTokenizerFast
|
5 |
import torch
|
6 |
import requests, json, base64
|
7 |
from io import BytesIO
|
8 |
-
from typing import List, Optional
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Initialize FastAPI app with versioning
|
12 |
app = FastAPI(
|
@@ -15,18 +27,69 @@ app = FastAPI(
|
|
15 |
description="An API for converting text to pictograms, supporting English, French, and Tamil.",
|
16 |
)
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
# Define schemas for requests and responses
|
|
|
20 |
class TranslationRequest(BaseModel):
|
21 |
src: str = Field(..., description="Source text to be translated.")
|
22 |
language: str = Field(..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.")
|
23 |
|
24 |
class TranslationResponse(BaseModel):
|
25 |
-
language: str
|
26 |
-
src: str
|
27 |
-
tgt: Optional[str] = None
|
28 |
-
pictogram_ids: List[Optional[int]] = None
|
29 |
-
image_base64: Optional[str] = None
|
30 |
|
31 |
|
32 |
# Load the model and tokenizer
|
@@ -37,20 +100,32 @@ model = MBartForConditionalGeneration.from_pretrained(model_path)
|
|
37 |
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
|
38 |
model = model.to(device)
|
39 |
|
|
|
|
|
|
|
40 |
# Load the pictogram dictionary from the JSON file
|
41 |
with open('pictogram_vocab.json', 'r') as f:
|
42 |
pictogram_dict = json.load(f)
|
43 |
|
44 |
-
#
|
45 |
def fetch_pictogram(picto_id: int):
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
|
47 |
response = requests.get(url)
|
48 |
if response.status_code == 200:
|
49 |
-
|
|
|
|
|
|
|
50 |
return None
|
51 |
|
52 |
-
# Generate an image from a sequence of pictogram IDs
|
53 |
-
def create_pictogram_image(pictogram_ids):
|
54 |
pictogram_images = []
|
55 |
for picto_id in pictogram_ids:
|
56 |
picto_image = fetch_pictogram(picto_id)
|
@@ -62,26 +137,56 @@ def create_pictogram_image(pictogram_ids):
|
|
62 |
total_width = sum(widths)
|
63 |
max_height = max(heights)
|
64 |
|
65 |
-
|
|
|
66 |
x_offset = 0
|
67 |
for img in pictogram_images:
|
68 |
-
|
|
|
|
|
|
|
|
|
69 |
x_offset += img.size[0]
|
70 |
|
71 |
return final_image
|
72 |
|
73 |
|
74 |
# Endpoint for health check
|
|
|
75 |
@app.get("/health", summary="Health Check", response_description="Health status")
|
76 |
-
def health_check():
|
77 |
return {"status": "healthy", "message": "API is up and running"}
|
78 |
|
79 |
# Main translation endpoint
|
80 |
-
@app.post("/v1/translate", summary="Translate Text to Pictograms",
|
81 |
-
def translate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
if request.language not in ["en", "fr", "ta"]:
|
83 |
raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
# Placeholder for unsupported languages
|
86 |
if request.language in ["en", "ta"]:
|
87 |
return TranslationResponse(
|
@@ -104,7 +209,7 @@ def translate(request: TranslationRequest):
|
|
104 |
|
105 |
if pictogram_ids:
|
106 |
# Generate pictogram image
|
107 |
-
final_image = create_pictogram_image(pictogram_ids)
|
108 |
if final_image:
|
109 |
img_byte_arr = BytesIO()
|
110 |
final_image.save(img_byte_arr, format="PNG")
|
|
|
1 |
+
import time
|
2 |
from PIL import Image
|
3 |
+
from fastapi import FastAPI, HTTPException, Query
|
4 |
+
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
5 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
6 |
from pydantic import BaseModel, Field
|
7 |
from transformers import MBartForConditionalGeneration, MBartTokenizerFast
|
8 |
import torch
|
9 |
import requests, json, base64
|
10 |
from io import BytesIO
|
11 |
+
from typing import List, Optional, Tuple
|
12 |
+
import os
|
13 |
+
import warnings
|
14 |
+
import logging
|
15 |
+
|
16 |
+
# Suppress specific FutureWarning from huggingface_hub
|
17 |
+
warnings.filterwarnings(
|
18 |
+
"ignore",
|
19 |
+
category=FutureWarning,
|
20 |
+
module="huggingface_hub.file_download"
|
21 |
+
)
|
22 |
|
23 |
# Initialize FastAPI app with versioning
|
24 |
app = FastAPI(
|
|
|
27 |
description="An API for converting text to pictograms, supporting English, French, and Tamil.",
|
28 |
)
|
29 |
|
30 |
+
# Set up custom logging format
|
31 |
+
logging.basicConfig(
|
32 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
33 |
+
level=logging.INFO
|
34 |
+
)
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
|
38 |
+
# Custom Middleware for request time logging
|
39 |
+
class RequestTimingMiddleware(BaseHTTPMiddleware):
|
40 |
+
async def dispatch(self, request, call_next):
|
41 |
+
start_time = time.time()
|
42 |
+
logger.info(f"Request received: {request.method} {request.url}")
|
43 |
+
|
44 |
+
# Process the request
|
45 |
+
response = await call_next(request)
|
46 |
+
|
47 |
+
end_time = time.time()
|
48 |
+
processing_time = end_time - start_time
|
49 |
+
|
50 |
+
# Log the timing
|
51 |
+
logger.info(f"Request processed: {request.method} {request.url}")
|
52 |
+
logger.info(f"Processing time: {processing_time:.4f}s")
|
53 |
+
|
54 |
+
# Return the response
|
55 |
+
return response
|
56 |
+
|
57 |
+
# Add the middleware to the app
|
58 |
+
app.add_middleware(RequestTimingMiddleware)
|
59 |
+
|
60 |
+
|
61 |
+
COLORS = {
|
62 |
+
"white": (255, 255, 255),
|
63 |
+
"black": (0, 0, 0),
|
64 |
+
"red": (255, 0, 0),
|
65 |
+
"green": (0, 255, 0),
|
66 |
+
"blue": (0, 0, 255),
|
67 |
+
"yellow": (255, 255, 0),
|
68 |
+
"cyan": (0, 255, 255),
|
69 |
+
"magenta": (255, 0, 255),
|
70 |
+
"gray": (128, 128, 128),
|
71 |
+
"orange": (255, 165, 0),
|
72 |
+
"purple": (128, 0, 128),
|
73 |
+
"brown": (165, 42, 42),
|
74 |
+
"pink": (255, 192, 203),
|
75 |
+
"lime": (0, 255, 0),
|
76 |
+
"teal": (0, 128, 128),
|
77 |
+
"navy": (0, 0, 128)
|
78 |
+
# Add more colors as needed
|
79 |
+
}
|
80 |
|
81 |
# Define schemas for requests and responses
|
82 |
+
|
83 |
class TranslationRequest(BaseModel):
|
84 |
src: str = Field(..., description="Source text to be translated.")
|
85 |
language: str = Field(..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.")
|
86 |
|
87 |
class TranslationResponse(BaseModel):
|
88 |
+
language: str = Field(..., description="Language of the source text.")
|
89 |
+
src: str = Field(..., description="Source text in the original language.")
|
90 |
+
tgt: Optional[str] = Field(None, description="Translated text in the original language.")
|
91 |
+
pictogram_ids: Optional[List[Optional[int]]] = Field(None, description="List of pictogram IDs corresponding to the translation.")
|
92 |
+
image_base64: Optional[str] = Field(None, description="Base64-encoded image of the pictograms, if generated.")
|
93 |
|
94 |
|
95 |
# Load the model and tokenizer
|
|
|
100 |
tokenizer = MBartTokenizerFast.from_pretrained(model_path)
|
101 |
model = model.to(device)
|
102 |
|
103 |
+
# Folder to store local pictogram images
|
104 |
+
pictogram_folder = 'pictogram_images'
|
105 |
+
|
106 |
# Load the pictogram dictionary from the JSON file
|
107 |
with open('pictogram_vocab.json', 'r') as f:
|
108 |
pictogram_dict = json.load(f)
|
109 |
|
110 |
+
# Function to fetch a pictogram image from the ARASAAC API, or use local file if present
|
111 |
def fetch_pictogram(picto_id: int):
|
112 |
+
# Check if the pictogram image exists locally
|
113 |
+
image_path = os.path.join(pictogram_folder, f"{picto_id}.png")
|
114 |
+
if os.path.exists(image_path):
|
115 |
+
return Image.open(image_path)
|
116 |
+
|
117 |
+
# If the image is not local, request from ARASAAC API
|
118 |
url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
|
119 |
response = requests.get(url)
|
120 |
if response.status_code == 200:
|
121 |
+
# Save the image locally
|
122 |
+
img = Image.open(BytesIO(response.content))
|
123 |
+
img.save(image_path) # Save the image to the local folder
|
124 |
+
return img
|
125 |
return None
|
126 |
|
127 |
+
# Generate an image from a sequence of pictogram IDs with a customizable background color
|
128 |
+
def create_pictogram_image(pictogram_ids, background_color=(255, 0, 0)): # Default to red background
|
129 |
pictogram_images = []
|
130 |
for picto_id in pictogram_ids:
|
131 |
picto_image = fetch_pictogram(picto_id)
|
|
|
137 |
total_width = sum(widths)
|
138 |
max_height = max(heights)
|
139 |
|
140 |
+
# Create a new image with the specified background color (RGBA with alpha channel)
|
141 |
+
final_image = Image.new('RGBA', (total_width, max_height), background_color + (255,)) # RGBA with alpha channel
|
142 |
x_offset = 0
|
143 |
for img in pictogram_images:
|
144 |
+
img = img.convert("RGBA") # Ensure the image has an alpha channel
|
145 |
+
img_with_bg = Image.new("RGBA", img.size, background_color + (255,)) # Create a solid background
|
146 |
+
img_with_bg.paste(img, (0, 0), img) # Paste the image on top of the background
|
147 |
+
|
148 |
+
final_image.paste(img_with_bg, (x_offset, 0), img_with_bg)
|
149 |
x_offset += img.size[0]
|
150 |
|
151 |
return final_image
|
152 |
|
153 |
|
154 |
# Endpoint for health check
|
155 |
+
@app.get("/")
|
156 |
@app.get("/health", summary="Health Check", response_description="Health status")
|
157 |
+
async def health_check():
|
158 |
return {"status": "healthy", "message": "API is up and running"}
|
159 |
|
160 |
# Main translation endpoint
|
161 |
+
@app.post("/v1/translate", summary="Translate Text to Pictograms", description="Translates text from a source language to a target language and converts the translation into pictograms. Optionally customize the background color of the generated pictogram images.")
|
162 |
+
async def translate(
|
163 |
+
request: TranslationRequest,
|
164 |
+
backgroundColor: Optional[str] = Query("white"),
|
165 |
+
backgroundColorRGB: Optional[Tuple[int, int, int]] = Query(None)
|
166 |
+
):
|
167 |
+
"""
|
168 |
+
Translate the provided source text into pictograms and return a corresponding image.
|
169 |
+
|
170 |
+
- **src**: The source text to be translated.
|
171 |
+
- **language**: The source language. Accepted values are 'en', 'fr', 'ta'.
|
172 |
+
- **backgroundColor**: (Optional) Background color for the pictogram image, specified by name (e.g., 'red'). Default is 'white'.
|
173 |
+
- **backgroundColorRGB**: (Optional) Background color for the image in RGB format (e.g., (255, 0, 0) for red). This overrides backgroundColor.
|
174 |
+
"""
|
175 |
+
|
176 |
+
# Ensure that both backgroundColor and backgroundColorRGB are not provided simultaneously
|
177 |
+
if backgroundColor and backgroundColorRGB:
|
178 |
+
raise HTTPException(status_code=400, detail="You cannot provide both backgroundColor and backgroundColorRGB at the same time.")
|
179 |
+
|
180 |
+
# Ensure that the given language is a valid one
|
181 |
if request.language not in ["en", "fr", "ta"]:
|
182 |
raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")
|
183 |
|
184 |
+
# Default to white RGB (255, 255, 255) if neither backgroundColor nor backgroundColorRGB is provided
|
185 |
+
if backgroundColorRGB: background_color = backgroundColorRGB
|
186 |
+
elif backgroundColor: background_color = COLORS.get(backgroundColor.lower(), (255, 255, 255))
|
187 |
+
else: background_color = (255, 255, 255)
|
188 |
+
|
189 |
+
# Temporary fix
|
190 |
# Placeholder for unsupported languages
|
191 |
if request.language in ["en", "ta"]:
|
192 |
return TranslationResponse(
|
|
|
209 |
|
210 |
if pictogram_ids:
|
211 |
# Generate pictogram image
|
212 |
+
final_image = create_pictogram_image(pictogram_ids, background_color)
|
213 |
if final_image:
|
214 |
img_byte_arr = BytesIO()
|
215 |
final_image.save(img_byte_arr, format="PNG")
|