feedlight42 committed on
Commit
aa9f8f2
·
1 Parent(s): c006108

better versioning, best practices

Browse files
Files changed (1) hide show
  1. main.py +66 -41
main.py CHANGED
@@ -1,18 +1,32 @@
1
  from PIL import Image
2
- from fastapi import FastAPI
3
- from pydantic import BaseModel
4
  from transformers import MBartForConditionalGeneration, MBartTokenizerFast
5
  import torch
6
  import requests, json, base64
7
  from io import BytesIO
 
8
 
9
- # Define the input schema
 
 
 
 
 
 
 
 
 
10
  class TranslationRequest(BaseModel):
11
- src: str
 
 
 
12
  language: str
 
 
 
13
 
14
- # Initialize FastAPI app
15
- app = FastAPI()
16
 
17
  # Load the model and tokenizer
18
  model_path = "feedlight42/mbart25-text2picto"
@@ -27,10 +41,12 @@ with open('pictogram_vocab.json', 'r') as f:
27
  pictogram_dict = json.load(f)
28
 
29
  # Fetch a pictogram image from the ARASAAC API
30
- def fetch_pictogram(picto_id):
31
  url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
32
  response = requests.get(url)
33
- return Image.open(BytesIO(response.content))
 
 
34
 
35
  # Generate an image from a sequence of pictogram IDs
36
  def create_pictogram_image(pictogram_ids):
@@ -54,44 +70,53 @@ def create_pictogram_image(pictogram_ids):
54
  return final_image
55
 
56
 
57
- @app.post("/translate")
 
 
 
 
 
 
58
  def translate(request: TranslationRequest):
59
- """
60
- Translate text to target language and generate pictogram tokens.
61
- """
62
- inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
 
 
 
 
 
 
 
63
 
64
- # Generate translation
 
65
  translated_tokens = model.generate(**inputs)
66
  tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
67
-
68
- # Process the translated sentence and map words to pictograms
69
- words = tgt_sentence.split() # Split sentence into words
70
  pictogram_ids = [pictogram_dict.get(word.lower(), None) for word in words]
71
- pictogram_ids = [picto_id for picto_id in pictogram_ids if picto_id is not None] # Remove None values
72
 
73
- # Check if there are pictogram IDs
74
  if pictogram_ids:
75
- # Generate the pictogram image
76
  final_image = create_pictogram_image(pictogram_ids)
77
-
78
- # Convert image to base64
79
- img_byte_arr = BytesIO()
80
- final_image.save(img_byte_arr, format='PNG')
81
- img_byte_arr.seek(0)
82
- encoded_image = base64.b64encode(img_byte_arr.read()).decode('utf-8')
83
-
84
- return {
85
- "src": request.src,
86
- "tgt": tgt_sentence,
87
- "pictograms": pictogram_ids,
88
- "image_base64": encoded_image
89
- }
90
- else:
91
- # Return a response without an image if no pictogram IDs are found
92
- return {
93
- "src": request.src,
94
- "tgt": tgt_sentence,
95
- "pictograms": pictogram_ids,
96
- "image_base64": None # No image if no pictograms were found
97
- }
 
1
  from PIL import Image
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel, Field
4
  from transformers import MBartForConditionalGeneration, MBartTokenizerFast
5
  import torch
6
  import requests, json, base64
7
  from io import BytesIO
8
+ from typing import Optional
9
 
10
+
11
# Initialize the FastAPI application; the metadata below is surfaced in
# the auto-generated OpenAPI / Swagger docs.
_api_metadata = {
    "title": "Text-to-Pictogram API",
    "version": "1.0.0",
    "description": "An API for converting text to pictograms, supporting English, French, and Tamil.",
}
app = FastAPI(**_api_metadata)
17
+
18
+
19
# Pydantic schemas for the translation endpoint.
class TranslationRequest(BaseModel):
    """Request payload accepted by the translation endpoint."""

    src: str = Field(..., description="Source text to be translated.")
    language: str = Field(..., description="Language of the source text. Accepted values: 'en', 'fr', 'ta'.")


class TranslationResponse(BaseModel):
    """Response payload for the translation endpoint.

    ``tgt`` and ``image_base64`` remain ``None`` when no translation or
    pictogram image could be produced for the request.
    """

    language: str
    src: str
    tgt: Optional[str] = None
    image_base64: Optional[str] = None
29
 
 
 
30
 
31
  # Load the model and tokenizer
32
  model_path = "feedlight42/mbart25-text2picto"
 
41
  pictogram_dict = json.load(f)
42
 
43
# Fetch a pictogram image from the ARASAAC API
def fetch_pictogram(picto_id: int):
    """Download the pictogram image for *picto_id* from the ARASAAC API.

    Returns a PIL ``Image`` on success, or ``None`` on any failure
    (non-200 status, network error/timeout, or a body that is not a
    decodable image) — callers already treat ``None`` as "no image".
    """
    url = f"https://api.arasaac.org/v1/pictograms/{picto_id}"
    try:
        # requests has no default timeout; without one a slow or
        # unreachable API would hang the request handler indefinitely.
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        return None
    if response.status_code == 200:
        try:
            return Image.open(BytesIO(response.content))
        except OSError:
            # Body was not a decodable image (UnidentifiedImageError is an OSError).
            return None
    return None
50
 
51
  # Generate an image from a sequence of pictogram IDs
52
  def create_pictogram_image(pictogram_ids):
 
70
  return final_image
71
 
72
 
73
# Lightweight liveness probe for monitoring / orchestration.
@app.get("/health", summary="Health Check", response_description="Health status")
def health_check():
    """Report that the service is reachable and serving requests."""
    payload = {"status": "healthy", "message": "API is up and running"}
    return payload
77
+
78
# Main translation endpoint
@app.post("/v1/translate", summary="Translate Text to Pictograms", response_model=TranslationResponse)
def translate(request: TranslationRequest):
    """Translate ``request.src`` and render the result as pictograms.

    Only 'fr' is currently backed by the model; 'en' and 'ta' are
    accepted but return a placeholder response with ``tgt`` and
    ``image_base64`` set to ``None``. Any other language code is
    rejected with HTTP 400.
    """
    if request.language not in ["en", "fr", "ta"]:
        raise HTTPException(status_code=400, detail="Invalid language. Accepted values: 'en', 'fr', 'ta'.")

    # Placeholder for unsupported languages
    if request.language in ["en", "ta"]:
        return TranslationResponse(
            language=request.language,
            src=request.src,
            tgt=None,
            image_base64=None,
        )

    # Translate using French model. no_grad: inference only — avoid
    # building the autograd graph on every request (wastes memory).
    inputs = tokenizer(request.src, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        translated_tokens = model.generate(**inputs)
    tgt_sentence = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # Map translated words to pictogram IDs; words without a dictionary
    # entry are skipped. Compare against None rather than truthiness so
    # a legitimate pictogram ID of 0 is not silently dropped.
    words = tgt_sentence.split()
    pictogram_ids = [
        picto_id
        for picto_id in (pictogram_dict.get(word.lower()) for word in words)
        if picto_id is not None
    ]

    if pictogram_ids:
        # Generate pictogram image
        final_image = create_pictogram_image(pictogram_ids)
        if final_image:
            img_byte_arr = BytesIO()
            final_image.save(img_byte_arr, format="PNG")
            encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
            return TranslationResponse(
                language=request.language,
                src=request.src,
                tgt=tgt_sentence,
                image_base64=encoded_image,
            )

    # No pictograms matched (or image generation failed): text-only response.
    return TranslationResponse(
        language=request.language,
        src=request.src,
        tgt=tgt_sentence,
        image_base64=None,
    )