feedlight42 committed on
Commit 1f74d92 · 1 Parent(s): 6cbe35a

implement logging for response time
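For context, the change wires a `RequestTimingMiddleware` into the FastAPI app so every request logs how long it took to process. Below is a minimal verification sketch, not part of the commit, assuming the model weights and `pictogram_vocab.json` that `main.py` loads are available locally:

```python
# Illustrative only; not part of this commit.
# Exercises the new RequestTimingMiddleware through FastAPI's test client.
from fastapi.testclient import TestClient

from main import app  # main.py as changed below; loads the model on import

client = TestClient(app)
response = client.get("/health")
print(response.json())  # {'status': 'healthy', 'message': 'API is up and running'}

# With the logging format configured in this commit, each request emits lines like:
#   <timestamp> - INFO - Request received: GET http://testserver/health
#   <timestamp> - INFO - Request processed: GET http://testserver/health
#   <timestamp> - INFO - Processing time: <elapsed>s
```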

Files changed (1)
  1. main.py +123 -18
main.py CHANGED
@@ -1,12 +1,24 @@
+import time
 from PIL import Image
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Query
+from fastapi.middleware.trustedhost import TrustedHostMiddleware
+from starlette.middleware.base import BaseHTTPMiddleware
 from pydantic import BaseModel, Field
 from transformers import MBartForConditionalGeneration, MBartTokenizerFast
 import torch
 import requests, json, base64
 from io import BytesIO
-from typing import List, Optional
-
+from typing import List, Optional, Tuple
+import os
+import warnings
+import logging
+
+# Suppress specific FutureWarning from huggingface_hub
+warnings.filterwarnings(
+    "ignore",
+    category=FutureWarning,
+    module="huggingface_hub.file_download"
+)
 
 # Initialize FastAPI app with versioning
 app = FastAPI(
@@ -15,18 +27,69 @@ app = FastAPI(
     description="An API for converting text to pictograms, supporting English, French, and Tamil.",
 )
 
+# Set up custom logging format
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(message)s",
+    level=logging.INFO
+)
+
+logger = logging.getLogger(__name__)
+
+# Custom Middleware for request time logging
+class RequestTimingMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request, call_next):
+        start_time = time.time()
+        logger.info(f"Request received: {request.method} {request.url}")
+
+        # Process the request
+        response = await call_next(request)
+
+        end_time = time.time()
+        processing_time = end_time - start_time
+
+        # Log the timing
+        logger.info(f"Request processed: {request.method} {request.url}")
+        logger.info(f"Processing time: {processing_time:.4f}s")
+
+        # Return the response
+        return response
+
+# Add the middleware to the app
+app.add_middleware(RequestTimingMiddleware)
+
+
+COLORS = {
+    "white": (255, 255, 255),
+    "black": (0, 0, 0),
+    "red": (255, 0, 0),
+    "green": (0, 255, 0),
+    "blue": (0, 0, 255),
+    "yellow": (255, 255, 0),
+    "cyan": (0, 255, 255),
+    "magenta": (255, 0, 255),
+    "gray": (128, 128, 128),
+    "orange": (255, 165, 0),
+    "purple": (128, 0, 128),
+    "brown": (165, 42, 42),
+    "pink": (255, 192, 203),
+    "lime": (0, 255, 0),
+    "teal": (0, 128, 128),
+    "navy": (0, 0, 128)
+    # Add more colors as needed
+}
 
 # Define schemas for requests and responses
+
 class TranslationRequest(BaseModel):
     src: str = Field(..., description="Source text to be translated.")
     language: str = Field(..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.")
 
 class TranslationResponse(BaseModel):
-    language: str
-    src: str
-    tgt: Optional[str] = None
-    pictogram_ids: List[Optional[int]] = None
-    image_base64: Optional[str] = None
+    language: str = Field(..., description="Language of the source text.")
+    src: str = Field(..., description="Source text in the original language.")
+    tgt: Optional[str] = Field(None, description="Translated text in the original language.")
+    pictogram_ids: Optional[List[Optional[int]]] = Field(None, description="List of pictogram IDs corresponding to the translation.")
+    image_base64: Optional[str] = Field(None, description="Base64-encoded image of the pictograms, if generated.")
 
 
 # Load the model and tokenizer
@@ -37,20 +100,32 @@ model = MBartForConditionalGeneration.from_pretrained(model_path)
 tokenizer = MBartTokenizerFast.from_pretrained(model_path)
 model = model.to(device)
 
+# Folder to store local pictogram images
+pictogram_folder = 'pictogram_images'
+
 # Load the pictogram dictionary from the JSON file
 with open('pictogram_vocab.json', 'r') as f:
     pictogram_dict = json.load(f)
 
-# Fetch a pictogram image from the ARASAAC API
+# Function to fetch a pictogram image from the ARASAAC API, or use local file if present
 def fetch_pictogram(picto_id: int):
+    # Check if the pictogram image exists locally
+    image_path = os.path.join(pictogram_folder, f"{picto_id}.png")
+    if os.path.exists(image_path):
+        return Image.open(image_path)
+
+    # If the image is not local, request from ARASAAC API
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    response = requests.get(url)
    if response.status_code == 200:
-        return Image.open(BytesIO(response.content))
+        # Save the image locally
+        img = Image.open(BytesIO(response.content))
+        img.save(image_path)  # Save the image to the local folder
+        return img
    return None
 
-# Generate an image from a sequence of pictogram IDs
-def create_pictogram_image(pictogram_ids):
+# Generate an image from a sequence of pictogram IDs with a customizable background color
+def create_pictogram_image(pictogram_ids, background_color=(255, 0, 0)):  # Default to red background
     pictogram_images = []
     for picto_id in pictogram_ids:
         picto_image = fetch_pictogram(picto_id)
@@ -62,26 +137,56 @@ def create_pictogram_image(pictogram_ids):
     total_width = sum(widths)
     max_height = max(heights)
 
-    final_image = Image.new('RGB', (total_width, max_height))
+    # Create a new image with the specified background color (RGBA with alpha channel)
+    final_image = Image.new('RGBA', (total_width, max_height), background_color + (255,))  # RGBA with alpha channel
     x_offset = 0
     for img in pictogram_images:
-        final_image.paste(img, (x_offset, 0))
+        img = img.convert("RGBA")  # Ensure the image has an alpha channel
+        img_with_bg = Image.new("RGBA", img.size, background_color + (255,))  # Create a solid background
+        img_with_bg.paste(img, (0, 0), img)  # Paste the image on top of the background
+
+        final_image.paste(img_with_bg, (x_offset, 0), img_with_bg)
         x_offset += img.size[0]
 
     return final_image
 
 
 # Endpoint for health check
+@app.get("/")
 @app.get("/health", summary="Health Check", response_description="Health status")
-def health_check():
+async def health_check():
     return {"status": "healthy", "message": "API is up and running"}
 
 # Main translation endpoint
-@app.post("/v1/translate", summary="Translate Text to Pictograms", response_model=TranslationResponse)
-def translate(request: TranslationRequest):
+@app.post("/v1/translate", summary="Translate Text to Pictograms", description="Translates text from a source language to a target language and converts the translation into pictograms. Optionally customize the background color of the generated pictogram images.")
+async def translate(
+    request: TranslationRequest,
+    backgroundColor: Optional[str] = Query("white"),
+    backgroundColorRGB: Optional[Tuple[int, int, int]] = Query(None)
+):
+    """
+    Translate the provided source text into pictograms and return a corresponding image.
+
+    - **src**: The source text to be translated.
+    - **language**: The source language. Accepted values are 'en', 'fr', 'ta'.
+    - **backgroundColor**: (Optional) Background color for the pictogram image, specified by name (e.g., 'red'). Default is 'white'.
+    - **backgroundColorRGB**: (Optional) Background color for the image in RGB format (e.g., (255, 0, 0) for red). This overrides backgroundColor.
+    """
+
+    # Ensure that both backgroundColor and backgroundColorRGB are not provided simultaneously
+    if backgroundColor and backgroundColorRGB:
+        raise HTTPException(status_code=400, detail="You cannot provide both backgroundColor and backgroundColorRGB at the same time.")
+
+    # Ensure that the given language is a valid one
     if request.language not in ["en", "fr", "ta"]:
         raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")
 
+    # Default to white RGB (255, 255, 255) if neither backgroundColor nor backgroundColorRGB is provided
+    if backgroundColorRGB: background_color = backgroundColorRGB
+    elif backgroundColor: background_color = COLORS.get(backgroundColor.lower(), (255, 255, 255))
+    else: background_color = (255, 255, 255)
+
+    # Temporary fix
     # Placeholder for unsupported languages
     if request.language in ["en", "ta"]:
         return TranslationResponse(
@@ -104,7 +209,7 @@ def translate(request: TranslationRequest):
 
     if pictogram_ids:
         # Generate pictogram image
-        final_image = create_pictogram_image(pictogram_ids)
+        final_image = create_pictogram_image(pictogram_ids, background_color)
         if final_image:
             img_byte_arr = BytesIO()
             final_image.save(img_byte_arr, format="PNG")
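After this change, the translate endpoint also accepts the `backgroundColor` and `backgroundColorRGB` query parameters. The snippet below is a hedged usage sketch only; the base URL and the French example sentence are assumptions, not part of the commit:

```python
# Illustrative client call; base URL and example sentence are assumptions.
import base64
import requests

BASE_URL = "http://localhost:8000"  # wherever the API is deployed

resp = requests.post(
    f"{BASE_URL}/v1/translate",
    params={"backgroundColor": "blue"},  # a named color from the COLORS table added above
    json={"src": "je mange une pomme", "language": "fr"},
)
resp.raise_for_status()
data = resp.json()
print(data.get("pictogram_ids"))

# When an image is generated, image_base64 holds a base64-encoded PNG.
if data.get("image_base64"):
    with open("pictograms.png", "wb") as f:
        f.write(base64.b64decode(data["image_base64"]))
```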