Update app.py
Browse files
@@ -1,39 +1,142 @@
1 |
import os
2 |
import torch
3 |
from fastapi import FastAPI, HTTPException
4 |
from fastapi.responses import StreamingResponse
5 |
from pydantic import BaseModel
6 |
from transformers import (
7 |
8 |
9 |
10 |
11 |
12 |
13 |
from io import BytesIO
14 |
import boto3
15 |
from botocore.exceptions import NoCredentialsError
16 |
from huggingface_hub import snapshot_download
17 |
18 |
19 |
20 |
21 |
22 |
AWS_REGION = os.getenv("AWS_REGION")
23 |
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
24 |
25 |
26 |
# Diccionario global de tokens y configuraciones
27 |
token_dict = {}
28 |
29 |
# Inicialización de la aplicación FastAPI
30 |
app = FastAPI()
31 |
32 |
# Modelo de solicitud
33 |
class GenerateRequest(BaseModel):
34 |
model_name: str
35 |
input_text: str
36 |
task_type: str
37 |
temperature: float = 1.0
38 |
max_new_tokens: int = 200
39 |
stream: bool = True
@@ -43,13 +146,52 @@ class GenerateRequest(BaseModel):
43 |
num_return_sequences: int = 1
44 |
do_sample: bool = True
45 |
chunk_delay: float = 0.0
46 |
47 |
48 |
# Clase para cargar y gestionar los modelos desde S3
49 |
class S3ModelLoader:
50 |
def __init__(self, bucket_name, aws_access_key_id
51 |
self.bucket_name = bucket_name
52 |
53 |
54 |
55 |
@@ -57,78 +199,110 @@ class S3ModelLoader:
57 |
58 |
59 |
def _get_s3_uri(self, model_name):
60 |
return f"
61 |
62 |
def load_model_and_tokenizer(self, model_name):
63 |
if model_name in token_dict:
64 |
return token_dict[model_name]
65 |
66 |
s3_uri = self._get_s3_uri(model_name)
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
85 |
86 |
# Asignar EOS y PAD token si no están definidos
87 |
if tokenizer.eos_token_id is None:
88 |
tokenizer.eos_token_id = tokenizer.pad_token_id
89 |
90 |
91 |
92 |
"model": model,
93 |
"tokenizer": tokenizer,
94 |
"pad_token_id": tokenizer.pad_token_id,
95 |
"eos_token_id": tokenizer.eos_token_id
96 |
97 |
98 |
# Subir los archivos del modelo y tokenizer a S3
99 |
self.s3_client.upload_file(model_path, self.bucket_name, f'{model_name}/model')
100 |
self.s3_client.upload_file(f'{model_path}/tokenizer', self.bucket_name, f'{model_name}/tokenizer')
101 |
102 |
# Eliminar los archivos locales después de haber subido a S3
103 |
104 |
105 |
return token_dict[model_name]
106 |
except NoCredentialsError:
107 |
raise HTTPException(status_code=500, detail="AWS credentials not found.")
108 |
except Exception as e:
109 |
raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
110 |
111 |
# Instanciación del cargador de modelos
112 |
113 |
114 |
# Función de generación de texto con streaming
115 |
async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
116 |
encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
117 |
input_length = encoded_input["input_ids"].shape[1]
118 |
remaining_tokens = max_length - input_length
119 |
120 |
if remaining_tokens <= 0:
121 |
yield ""
122 |
123 |
generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
124 |
125 |
def stop_criteria(input_ids, scores):
126 |
decoded_output = tokenizer.decode(
127 |
return decoded_output in stop_sequences
128 |
129 |
stopping_criteria = StoppingCriteriaList([stop_criteria])
130 |
131 |
output_text = ""
132 |
outputs = model.generate(
133 |
134 |
@@ -142,82 +316,380 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
142 |
143 |
144 |
145 |
146 |
for output in outputs.sequences:
147 |
for token_id in output:
148 |
token = tokenizer.decode(token_id, skip_special_tokens=True)
149 |
yield token
150 |
await asyncio.sleep(chunk_delay)
151 |
152 |
if stop_sequences and any(stop in output_text for stop in stop_sequences):
153 |
yield output_text
154 |
155 |
156 |
157 |
158 |
async def generate(request: GenerateRequest):
159 |
160 |
model_name = request.model_name
161 |
input_text = request.input_text
162 |
temperature = request.temperature
163 |
max_new_tokens = request.max_new_tokens
164 |
stream = request.stream
165 |
top_p = request.top_p
166 |
top_k = request.top_k
167 |
repetition_penalty = request.repetition_penalty
168 |
num_return_sequences = request.num_return_sequences
169 |
do_sample = request.do_sample
170 |
chunk_delay = request.chunk_delay
171 |
stop_sequences = request.stop_sequences
172 |
173 |
# Cargar el modelo y tokenizer desde S3 si no existe
174 |
model_data = model_loader.load_model_and_tokenizer(model_name)
175 |
model = model_data["model"]
176 |
tokenizer = model_data["tokenizer"]
177 |
pad_token_id = model_data["pad_token_id"]
178 |
eos_token_id = model_data["eos_token_id"]
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
return StreamingResponse(
194 |
stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
195 |
196 |
197 |
198 |
except Exception as e:
199 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
200 |
201 |
202 |
203 |
async def
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
img_byte_arr = BytesIO()
212 |
image.save(img_byte_arr, format="PNG")
213 |
214 |
215 |
216 |
217 |
except Exception as e:
218 |
219 |
220 |
# Ejecutar el servidor FastAPI con Uvicorn
221 |
if __name__ == "__main__":
222 |
223 |
1 |
import os
2 |
import torch
3 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends, BackgroundTasks, Request, Query, APIRouter, Path, Body, status, Response, Header
4 |
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse, PlainTextResponse, RedirectResponse
5 |
from pydantic import BaseModel, validator, Field, root_validator, EmailStr, constr, ValidationError
6 |
from transformers import (
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
from io import BytesIO
29 |
import boto3
30 |
from botocore.exceptions import NoCredentialsError, ClientError
31 |
from huggingface_hub import snapshot_download
32 |
import asyncio
33 |
import tempfile
34 |
import hashlib
35 |
from PIL import Image
36 |
import base64
37 |
from typing import Optional, List, Union, Dict, Any
38 |
import uuid
39 |
import subprocess
40 |
import json
41 |
from starlette.middleware.cors import CORSMiddleware
42 |
import numpy as np
43 |
from typing import Dict, Any
44 |
from fastapi.staticfiles import StaticFiles
45 |
from fastapi.templating import Jinja2Templates
46 |
from fastapi.middleware.gzip import GZipMiddleware
47 |
from transformers import AutoImageProcessor, pipeline
48 |
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
49 |
from fastapi.security.api_key import APIKeyCookie
50 |
from fastapi import Depends, Security, status, APIRouter, UploadFile, File, Request
51 |
from fastapi.security import APIKeyHeader, OAuth2PasswordRequestForm
52 |
from passlib.context import CryptContext
53 |
from jose import JWTError, jwt
54 |
from datetime import datetime, timedelta
55 |
from starlette.requests import Request
56 |
import logging
57 |
from pydantic import EmailStr, constr, ValidationError
58 |
from database import insert_user, get_user, delete_user, update_user, create_db_and_table
59 |
from starlette.middleware import Middleware
60 |
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
61 |
from starlette.types import ASGIApp
62 |
import uvicorn
63 |
from starlette.responses import StreamingResponse
64 |
import logging
65 |
from pydantic import EmailStr, constr, ValidationError
66 |
from database import insert_user, get_user, delete_user, update_user, create_db_and_table, get_all_users
67 |
from starlette.middleware import Middleware
68 |
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
69 |
from starlette.types import ASGIApp
70 |
import uvicorn
71 |
from starlette.responses import StreamingResponse
72 |
import logging
73 |
from fastapi.exceptions import RequestValidationError
74 |
from fastapi import Request, status, Depends
75 |
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
76 |
from jose import JWTError, jwt
77 |
from passlib.context import CryptContext
78 |
from datetime import datetime, timedelta
79 |
from typing import Optional
80 |
81 |
#setting up logging
82 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
83 |
logger = logging.getLogger(__name__)
84 |
85 |
#JWT Settings
86 |
SECRET_KEY = os.getenv("SECRET_KEY")
87 |
if not SECRET_KEY:
88 |
raise ValueError("SECRET_KEY must be set.")
89 |
90 |
91 |
92 |
#Password Hashing
93 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
94 |
95 |
#Database connection - replace with your database setup
96 |
#Example using SQLite
97 |
import sqlite3
98 |
conn = sqlite3.connect('users.db')
99 |
cursor = conn.cursor()
100 |
101 |
102 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
103 |
104 |
#API Key
105 |
API_KEY = os.getenv("API_KEY")
106 |
api_key_header = APIKeyHeader(name="X-API-Key")
107 |
108 |
109 |
110 |
111 |
AWS_REGION = os.getenv("AWS_REGION")
112 |
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
113 |
114 |
TEMP_DIR = "/tmp"
115 |
STATIC_DIR = "static"
116 |
TEMPLATES = Jinja2Templates(directory="templates")
117 |
118 |
app = FastAPI()
119 |
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
120 |
121 |
122 |
origins = ["*"]
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
class User(BaseModel):
    """Pydantic model with the ``username``, ``email`` and ``password`` of a user."""

    # Constrained to between 3 and 50 characters.
    username: constr(min_length=3, max_length=50)
    # Validated by pydantic's EmailStr type.
    email: EmailStr
    # Constrained to at least 8 characters.
    # NOTE(review): presumably the plaintext password prior to hashing with
    # pwd_context — confirm against the code that consumes this model.
    password: constr(min_length=8)
135 |
136 |
class GenerateRequest(BaseModel):
137 |
model_name: str
138 |
input_text: Optional[str] = Field(None, description="Input text for generation.")
139 |
task_type: str = Field(..., description="Type of generation task (text, image, audio, video, classification, translation, question-answering, speech-to-text, text-to-speech, image-segmentation, feature-extraction, token-classification, fill-mask, image-inpainting, image-super-resolution, object-detection, image-captioning, audio-transcription, summarization).")
140 |
temperature: float = 1.0
141 |
max_new_tokens: int = 200
142 |
stream: bool = True
146 |
num_return_sequences: int = 1
147 |
do_sample: bool = True
148 |
chunk_delay: float = 0.0
149 |
stop_sequences: List[str] = []
150 |
image_file: Optional[UploadFile] = None
151 |
source_language: Optional[str] = None
152 |
target_language: Optional[str] = None
153 |
context: Optional[str] = None
154 |
audio_file: Optional[UploadFile] = None
155 |
raw_input: Optional[Union[str, bytes]] = None # for feature extraction
156 |
masked_text: Optional[str] = None # for fill-mask
157 |
mask_image: Optional[UploadFile] = None # for image inpainting
158 |
low_res_image: Optional[UploadFile] = None # for image super-resolution
159 |
160 |
161 |
162 |
def validate_task_type(cls, value):
163 |
allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
164 |
if value not in allowed_types:
165 |
raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
166 |
return value
167 |
168 |
169 |
def check_input(cls, values):
170 |
task_type = values.get("task_type")
171 |
if task_type == "text" and values.get("input_text") is None:
172 |
raise ValueError("input_text is required for text generation.")
173 |
elif task_type == "speech-to-text" and values.get("audio_file") is None:
174 |
raise ValueError("audio_file is required for speech-to-text.")
175 |
elif task_type == "classification" and values.get("image_file") is None:
176 |
raise ValueError("image_file is required for image classification.")
177 |
elif task_type == "image-segmentation" and values.get("image_file") is None:
178 |
raise ValueError("image_file is required for image segmentation.")
179 |
elif task_type == "feature-extraction" and values.get("raw_input") is None:
180 |
raise ValueError("raw_input is required for feature extraction.")
181 |
elif task_type == "fill-mask" and values.get("masked_text") is None:
182 |
raise ValueError("masked_text is required for fill-mask.")
183 |
elif task_type == "image-inpainting" and (values.get("image_file") is None or values.get("mask_image") is None):
184 |
raise ValueError("image_file and mask_image are required for image inpainting.")
185 |
elif task_type == "image-super-resolution" and values.get("low_res_image") is None:
186 |
raise ValueError("low_res_image is required for image super-resolution.")
187 |
return values
188 |
189 |
190 |
191 |
class S3ModelLoader:
192 |
def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
193 |
self.bucket_name = bucket_name
194 |
self.s3 = boto3.client(
195 |
196 |
197 |
199 |
200 |
201 |
def _get_s3_uri(self, model_name):
202 |
return f"{self.bucket_name}/{model_name.replace('/', '-')}"
203 |
204 |
def load_model_and_tokenizer(self, model_name, task_type):
205 |
s3_uri = self._get_s3_uri(model_name)
206 |
207 |
self.s3.head_object(Bucket=self.bucket_name, Key=f'{s3_uri}/config.json')
208 |
except ClientError as e:
209 |
if e.response['Error']['Code'] == '404':
210 |
with tempfile.TemporaryDirectory() as tmpdir:
211 |
model_path = snapshot_download(model_name, token=HUGGINGFACE_HUB_TOKEN, cache_dir=tmpdir)
212 |
self._upload_model_to_s3(model_path, s3_uri)
213 |
214 |
raise HTTPException(status_code=500, detail=f"Error accessing S3: {e}")
215 |
return self._load_from_s3(s3_uri, task_type)
216 |
217 |
def _upload_model_to_s3(self, model_path, s3_uri):
218 |
for root, _, files in os.walk(model_path):
219 |
for file in files:
220 |
local_path = os.path.join(root, file)
221 |
s3_path = os.path.join(s3_uri, os.path.relpath(local_path, model_path))
222 |
self.s3.upload_file(local_path, self.bucket_name, s3_path)
223 |
224 |
def _load_from_s3(self, s3_uri, task_type):
225 |
with tempfile.TemporaryDirectory() as tmpdir:
226 |
model_path = os.path.join(tmpdir, s3_uri)
227 |
os.makedirs(model_path, exist_ok=True)
228 |
self.s3.download_file(self.bucket_name, f"{s3_uri}/config.json", os.path.join(model_path, "config.json"))
229 |
if task_type == "text":
230 |
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True)
231 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
232 |
if tokenizer.eos_token_id is None:
233 |
tokenizer.eos_token_id = tokenizer.pad_token_id
234 |
return {"model": model, "tokenizer": tokenizer, "pad_token_id": tokenizer.pad_token_id, "eos_token_id": tokenizer.eos_token_id}
235 |
elif task_type in ["image", "audio", "video"]:
236 |
processor = AutoProcessor.from_pretrained(model_path)
237 |
pipeline_function = pipeline(task_type, model=model_path, device=0 if torch.cuda.is_available() else -1, processor=processor)
238 |
return {"pipeline": pipeline_function}
239 |
elif task_type == "classification":
240 |
model = AutoModelForImageClassification.from_pretrained(model_path)
241 |
processor = AutoProcessor.from_pretrained(model_path)
242 |
return {"model": model, "processor": processor}
243 |
elif task_type == "translation":
244 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
245 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
246 |
return {"model": model, "tokenizer": tokenizer}
247 |
elif task_type == "question-answering":
248 |
model = AutoModelForQuestionAnswering.from_pretrained(model_path)
249 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
250 |
return {"model": model, "tokenizer": tokenizer}
251 |
elif task_type == "speech-to-text":
252 |
model = pipeline("automatic-speech-recognition", model=model_path, device=0 if torch.cuda.is_available() else -1)
253 |
return {"pipeline": model}
254 |
elif task_type == "text-to-speech":
255 |
model = pipeline("text-to-speech", model=model_path, device=0 if torch.cuda.is_available() else -1)
256 |
return {"pipeline": model}
257 |
elif task_type == "image-segmentation":
258 |
model = pipeline("image-segmentation", model=model_path, device=0 if torch.cuda.is_available() else -1)
259 |
return {"pipeline": model}
260 |
elif task_type == "feature-extraction":
261 |
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
262 |
return {"feature_extractor": feature_extractor}
263 |
elif task_type == "token-classification":
264 |
model = AutoModelForTokenClassification.from_pretrained(model_path)
265 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
266 |
return {"model": model, "tokenizer": tokenizer}
267 |
elif task_type == "fill-mask":
268 |
model = AutoModelForMaskedLM.from_pretrained(model_path)
269 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
270 |
return {"model": model, "tokenizer": tokenizer}
271 |
elif task_type == "image-inpainting":
272 |
model = pipeline("image-inpainting", model=model_path, device=0 if torch.cuda.is_available() else -1)
273 |
return {"pipeline": model}
274 |
elif task_type == "image-super-resolution":
275 |
model = pipeline("image-super-resolution", model=model_path, device=0 if torch.cuda.is_available() else -1)
276 |
return {"pipeline": model}
277 |
elif task_type == "object-detection":
278 |
model = pipeline("object-detection", model=model_path, device=0 if torch.cuda.is_available() else -1)
279 |
image_processor = AutoImageProcessor.from_pretrained(model_path)
280 |
return {"pipeline": model, "image_processor": image_processor}
281 |
elif task_type == "image-captioning":
282 |
model = pipeline("image-captioning", model=model_path, device=0 if torch.cuda.is_available() else -1)
283 |
return {"pipeline": model}
284 |
elif task_type == "audio-transcription":
285 |
model = pipeline("automatic-speech-recognition", model=model_path, device=0 if torch.cuda.is_available() else -1)
286 |
return {"pipeline": model}
287 |
elif task_type == "summarization":
288 |
model = pipeline("summarization", model=model_path, device=0 if torch.cuda.is_available() else -1)
289 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
290 |
return {"model": model, "tokenizer": tokenizer}
291 |
292 |
raise ValueError("Unsupported task type")
293 |
294 |
async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
295 |
encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
296 |
input_length = encoded_input["input_ids"].shape[1]
297 |
max_length = model.config.max_length
298 |
remaining_tokens = max_length - input_length
299 |
if remaining_tokens <= 0:
300 |
yield ""
301 |
generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
302 |
def stop_criteria(input_ids, scores):
303 |
decoded_output = tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
304 |
return decoded_output in stop_sequences
305 |
stopping_criteria = StoppingCriteriaList([stop_criteria])
306 |
outputs = model.generate(
307 |
308 |
316 |
317 |
318 |
319 |
for output in outputs.sequences:
320 |
for token_id in output:
321 |
token = tokenizer.decode(token_id, skip_special_tokens=True)
322 |
yield token
323 |
324 |
325 |
326 |
327 |
def get_model_data(request: GenerateRequest):
    """Load the model artefacts requested by ``request``.

    Forwards the request's ``model_name`` and ``task_type`` to the
    module-level ``model_loader`` and returns its result unchanged.
    """
    return model_loader.load_model_and_tokenizer(request.model_name, request.task_type)
329 |
330 |
async def verify_api_key(api_key: str = Depends(api_key_header)):
    """Reject the request unless the supplied API key equals ``API_KEY``.

    Args:
        api_key: Value extracted by the ``api_key_header`` dependency.

    Raises:
        HTTPException: 401 when the key does not match, or when either the
            supplied key or the configured ``API_KEY`` is missing.
    """
    import secrets  # local import: constant-time comparison of the secret

    # Fail closed: with no configured key (or no supplied key) nothing may
    # pass. A plain != would accept None == None and leaks timing information.
    key_matches = (
        api_key is not None
        and API_KEY is not None
        and secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8"))
    )
    if not key_matches:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
333 |
334 |
335 |
@app.post("/generate", dependencies=[Depends(verify_api_key)])
336 |
async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, model_data = Depends(get_model_data)):
337 |
338 |
device = "cuda" if torch.cuda.is_available() else "cpu"
339 |
if request.task_type == "text":
340 |
model = model_data["model"].to(device)
341 |
tokenizer = model_data["tokenizer"]
342 |
generation_config = GenerationConfig(
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
async def stream_with_tokens():
352 |
async for token in stream_text(model, tokenizer, request.input_text, generation_config, request.stop_sequences, device, request.chunk_delay):
353 |
yield f"Token: {token}\n"
354 |
return StreamingResponse(stream_with_tokens(), media_type="text/plain")
355 |
elif request.task_type in ["image", "audio", "video"]:
356 |
pipeline = model_data["pipeline"]
357 |
result = pipeline(request.input_text)
358 |
if request.task_type == "image":
359 |
image = result[0]
360 |
img_byte_arr = BytesIO()
361 |
image.save(img_byte_arr, format="PNG")
362 |
363 |
return StreamingResponse(img_byte_arr, media_type="image/png")
364 |
elif request.task_type == "audio":
365 |
audio = result[0]
366 |
audio_byte_arr = BytesIO()
367 |
audio.save(audio_byte_arr, format="wav")
368 |
369 |
return StreamingResponse(audio_byte_arr, media_type="audio/wav")
370 |
elif request.task_type == "video":
371 |
video = result[0]
372 |
video_byte_arr = BytesIO()
373 |
video.save(video_byte_arr, format="mp4")
374 |
375 |
return StreamingResponse(video_byte_arr, media_type="video/mp4")
376 |
elif request.task_type == "classification":
377 |
if request.image_file is None:
378 |
raise HTTPException(status_code=400, detail="Image file is required for classification.")
379 |
contents = await request.image_file.read()
380 |
image = Image.open(BytesIO(contents)).convert("RGB")
381 |
model = model_data["model"].to(device)
382 |
processor = model_data["processor"]
383 |
inputs = processor(images=image, return_tensors="pt").to(device)
384 |
with torch.no_grad():
385 |
outputs = model(**inputs)
386 |
predicted_class_idx = outputs.logits.argmax().item()
387 |
predicted_class = model.config.id2label[predicted_class_idx]
388 |
return JSONResponse({"predicted_class": predicted_class})
389 |
elif request.task_type == "translation":
390 |
if request.source_language is None or request.target_language is None:
391 |
raise HTTPException(status_code=400, detail="Source and target languages are required for translation.")
392 |
model = model_data["model"].to(device)
393 |
tokenizer = model_data["tokenizer"]
394 |
inputs = tokenizer(request.input_text, return_tensors="pt").to(device)
395 |
with torch.no_grad():
396 |
outputs = model.generate(**inputs)
397 |
translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
398 |
return JSONResponse({"translation": translation})
399 |
elif request.task_type == "question-answering":
400 |
if request.context is None:
401 |
raise HTTPException(status_code=400, detail="Context is required for question answering.")
402 |
model = model_data["model"].to(device)
403 |
tokenizer = model_data["tokenizer"]
404 |
inputs = tokenizer(question=request.input_text, context=request.context, return_tensors="pt").to(device)
405 |
with torch.no_grad():
406 |
outputs = model(**inputs)
407 |
answer_start = torch.argmax(outputs.start_logits)
408 |
answer_end = torch.argmax(outputs.end_logits) + 1
409 |
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
410 |
return JSONResponse({"answer": answer})
411 |
elif request.task_type == "speech-to-text":
412 |
if request.audio_file is None:
413 |
raise HTTPException(status_code=400, detail="Audio file is required for speech-to-text.")
414 |
contents = await request.audio_file.read()
415 |
pipeline = model_data["pipeline"]
416 |
417 |
transcription = pipeline(contents, sampling_rate=16000)[0]["text"] # Assuming 16kHz sampling rate
418 |
return JSONResponse({"transcription": transcription})
419 |
except Exception as e:
420 |
raise HTTPException(status_code=500, detail=f"Error during speech-to-text: {str(e)}")
421 |
422 |
elif request.task_type == "text-to-speech":
423 |
if not request.input_text:
424 |
raise HTTPException(status_code=400, detail="Input text is required for text-to-speech.")
425 |
pipeline = model_data["pipeline"]
426 |
427 |
audio = pipeline(request.input_text)[0]
428 |
file_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}.wav")
429 |
430 |
background_tasks.add_task(os.remove, file_path)
431 |
return FileResponse(file_path, media_type="audio/wav")
432 |
except Exception as e:
433 |
raise HTTPException(status_code=500, detail=f"Error during text-to-speech: {str(e)}")
434 |
435 |
elif request.task_type == "image-segmentation":
436 |
if request.image_file is None:
437 |
raise HTTPException(status_code=400, detail="Image file is required for image segmentation.")
438 |
contents = await request.image_file.read()
439 |
image = Image.open(BytesIO(contents)).convert("RGB")
440 |
pipeline = model_data["pipeline"]
441 |
result = pipeline(image)
442 |
mask = result[0]['mask']
443 |
mask_byte_arr = BytesIO()
444 |
mask.save(mask_byte_arr, format="PNG")
445 |
446 |
return StreamingResponse(mask_byte_arr, media_type="image/png")
447 |
elif request.task_type == "feature-extraction":
448 |
if request.raw_input is None:
449 |
raise HTTPException(status_code=400, detail="raw_input is required for feature extraction.")
450 |
feature_extractor = model_data["feature_extractor"]
451 |
452 |
if isinstance(request.raw_input, str):
453 |
inputs = feature_extractor(text=request.raw_input, return_tensors="pt")
454 |
elif isinstance(request.raw_input, bytes):
455 |
image = Image.open(BytesIO(request.raw_input)).convert("RGB")
456 |
inputs = feature_extractor(images=image, return_tensors="pt")
457 |
458 |
raise ValueError("Unsupported raw_input type.")
459 |
features = inputs.pixel_values # Adjust according to your feature extractor
460 |
return JSONResponse({"features": features.tolist()})
461 |
except Exception as fe:
462 |
raise HTTPException(status_code=400, detail=f"Error during feature extraction: {fe}")
463 |
elif request.task_type == "token-classification":
464 |
if request.input_text is None:
465 |
raise HTTPException(status_code=400, detail="Input text is required for token classification.")
466 |
model = model_data["model"].to(device)
467 |
tokenizer = model_data["tokenizer"]
468 |
inputs = tokenizer(request.input_text, return_tensors="pt", padding=True, truncation=True)
469 |
with torch.no_grad():
470 |
outputs = model(**inputs)
471 |
predictions = outputs.logits.argmax(dim=-1)
472 |
predicted_labels = [model.config.id2label[label_id] for label_id in predictions[0].tolist()]
473 |
return JSONResponse({"predicted_labels": predicted_labels})
474 |
elif request.task_type == "fill-mask":
475 |
if request.masked_text is None:
476 |
raise HTTPException(status_code=400, detail="masked_text is required for fill-mask.")
477 |
model = model_data["model"].to(device)
478 |
tokenizer = model_data["tokenizer"]
479 |
inputs = tokenizer(request.masked_text, return_tensors="pt")
480 |
with torch.no_grad():
481 |
outputs = model(**inputs)
482 |
logits = outputs.logits
483 |
masked_index = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
484 |
predicted_token_id = torch.argmax(logits[0, masked_index])
485 |
predicted_token = tokenizer.decode(predicted_token_id)
486 |
return JSONResponse({"predicted_token": predicted_token})
487 |
elif request.task_type == "image-inpainting":
488 |
if request.image_file is None or request.mask_image is None:
489 |
raise HTTPException(status_code=400, detail="image_file and mask_image are required for image inpainting.")
490 |
image_contents = await request.image_file.read()
491 |
mask_contents = await request.mask_image.read()
492 |
image = Image.open(BytesIO(image_contents)).convert("RGB")
493 |
mask = Image.open(BytesIO(mask_contents)).convert("L") # Assuming mask is grayscale
494 |
pipeline = model_data["pipeline"]
495 |
result = pipeline(image, mask)
496 |
inpainted_image = result[0]
497 |
img_byte_arr = BytesIO()
498 |
inpainted_image.save(img_byte_arr, format="PNG")
499 |
500 |
return StreamingResponse(img_byte_arr, media_type="image/png")
501 |
elif request.task_type == "image-super-resolution":
502 |
if request.low_res_image is None:
503 |
raise HTTPException(status_code=400, detail="low_res_image is required for image super-resolution.")
504 |
contents = await request.low_res_image.read()
505 |
image = Image.open(BytesIO(contents)).convert("RGB")
506 |
pipeline = model_data["pipeline"]
507 |
result = pipeline(image)
508 |
upscaled_image = result[0]
509 |
img_byte_arr = BytesIO()
510 |
upscaled_image.save(img_byte_arr, format="PNG")
511 |
512 |
return StreamingResponse(img_byte_arr, media_type="image/png")
513 |
elif request.task_type == "object-detection":
514 |
if request.image_file is None:
515 |
raise HTTPException(status_code=400, detail="Image file is required for object detection.")
516 |
contents = await request.image_file.read()
517 |
image = Image.open(BytesIO(contents)).convert("RGB")
518 |
pipeline = model_data["pipeline"]
519 |
image_processor = model_data["image_processor"]
520 |
inputs = image_processor(images=image, return_tensors="pt")
521 |
with torch.no_grad():
522 |
outputs = pipeline(image)
523 |
detections = outputs
524 |
return JSONResponse({"detections": detections})
525 |
elif request.task_type == "image-captioning":
526 |
if request.image_file is None:
527 |
raise HTTPException(status_code=400, detail="Image file is required for image captioning.")
528 |
contents = await request.image_file.read()
529 |
image = Image.open(BytesIO(contents)).convert("RGB")
530 |
pipeline = model_data["pipeline"]
531 |
caption = pipeline(image)[0]['generated_text']
532 |
return JSONResponse({"caption": caption})
533 |
elif request.task_type == "audio-transcription":
534 |
if request.audio_file is None:
535 |
raise HTTPException(status_code=400, detail="Audio file is required for audio transcription.")
536 |
537 |
contents = await request.audio_file.read()
538 |
pipeline = model_data["pipeline"]
539 |
540 |
transcription = pipeline(contents, sampling_rate=16000)[0]["text"] # Assuming 16kHz sampling rate
541 |
return JSONResponse({"transcription": transcription})
542 |
except Exception as e:
543 |
raise HTTPException(status_code=500, detail=f"Error during audio transcription (pipeline): {str(e)}")
544 |
except Exception as e:
545 |
raise HTTPException(status_code=500, detail=f"Error during audio transcription (file read): {str(e)}")
546 |
elif request.task_type == "summarization":
547 |
if request.input_text is None:
548 |
raise HTTPException(status_code=400, detail="Input text is required for summarization.")
549 |
model = model_data["model"].to(device)
550 |
tokenizer = model_data["tokenizer"]
551 |
inputs = tokenizer(request.input_text, return_tensors="pt", truncation=True, max_length=512) # added max_length for summarization
552 |
with torch.no_grad():
553 |
outputs = model.generate(**inputs)
554 |
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
555 |
return JSONResponse({"summary": summary})
556 |
557 |
558 |
raise HTTPException(status_code=500, detail=f"Unsupported task type")
559 |
except Exception as e:
560 |
logger.exception(f"Internal server error: {str(e)}")
561 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
562 |
563 |
564 |
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the ``index.html`` template for the site root."""
    # The template context carries the current request under the "request" key.
    return TEMPLATES.TemplateResponse("index.html", {"request": request})
567 |
568 |
569 |
async def health_check():
    """Return a fixed payload reporting the service as healthy."""
    return {"status": "healthy"}
571 |
572 |
# Authentication Endpoints
573 |
574 |
# NOTE(review): ``Token`` is declared further down in this file; the decorator
# argument is evaluated at import time, so confirm it is already bound here.
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange a username/password form for a bearer access token.

    Args:
        form_data: OAuth2 password form carrying ``username`` and ``password``.

    Returns:
        A dict with ``access_token`` (the encoded token) and
        ``token_type`` set to ``"bearer"``.

    Raises:
        HTTPException: When ``authenticate_user`` rejects the credentials.
    """
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        # 401 pairs with the WWW-Authenticate challenge header sent below.
        # NOTE(review): the original status_code argument was not visible;
        # confirm 401 is the intended value.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user["username"]},
        expires_delta=access_token_expires,
    )
    return {"access_token": access_token, "token_type": "bearer"}
586 |
587 |
def authenticate_user(username: str, password: str):
    """Check a username/password pair against the stored user record.

    Args:
        username: Name used to look the user up via ``get_user``.
        password: Plain-text password verified against the record's
            ``hashed_password`` with ``pwd_context.verify``.

    Returns:
        ``{"username": ...}`` when the user exists and the password
        verifies, otherwise ``None``.
    """
    user = get_user(username)
    if user and pwd_context.verify(password, user.hashed_password):
        return {"username": user.username}
    return None
592 |
593 |
def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):
    """Encode ``data`` plus an ``exp`` claim into a signed token.

    Args:
        data: Claims to embed; copied so the caller's dict is not mutated.
        expires_delta: Lifetime of the token. When falsy (``None`` or a
            zero delta) a 15 minute lifetime is used instead.

    Returns:
        The value returned by ``jwt.encode`` for the claims, signed with
        ``SECRET_KEY`` using ``ALGORITHM``.
    """
    to_encode = data.copy()
    # Fall back to the default lifetime only when no delta was supplied;
    # without this branch the default would overwrite the requested expiry.
    expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
602 |
603 |
class Token(BaseModel):
    """Response schema used by the ``/token`` endpoint."""

    # Encoded access token string.
    access_token: str
    # Token scheme label; the login handler returns "bearer".
    token_type: str
606 |
607 |
608 |
609 |
# NOTE(review): no route decorator is visible directly above this handler, and
# ``get_current_user`` is defined below it. The default argument is evaluated
# at definition time, so confirm the dependency is already bound here.
async def read_users_me(current_user: str = Depends(get_current_user)):
    """Return the username of the authenticated caller.

    Args:
        current_user: Value produced by the ``get_current_user`` dependency.

    Returns:
        ``{"username": current_user}``.
    """
    return {"username": current_user}
611 |
612 |
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve a bearer token to the username it was issued for.

    The token is decoded with ``SECRET_KEY``/``ALGORITHM``, its ``sub`` claim
    is read as the username, and the user is looked up with ``get_user``.

    Args:
        token: Token string supplied by the ``oauth2_scheme`` dependency.

    Returns:
        The username stored in the token's ``sub`` claim.

    Raises:
        HTTPException: When the token cannot be decoded, has no ``sub``
            claim, or names a user that ``get_user`` does not return.
    """
    # NOTE(review): the original status_code argument was not visible; 401 is
    # used to match the WWW-Authenticate challenge header. Confirm.
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Only the decode call can raise JWTError, so keep the try body minimal.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception
    username: str = payload.get("sub")
    if username is None:
        raise credentials_exception
    if get_user(username) is None:
        raise credentials_exception
    return username
630 |
631 |
632 |
@app.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
async def create_user(user: User):
    """Register a new user with a hashed password.

    Args:
        user: Submitted user data; ``username``, ``email`` and ``password``
            are read from it.

    Returns:
        A ``User`` built from the record returned by ``insert_user``.

    Raises:
        HTTPException: 500 when ``insert_user`` returns a falsy value or any
            other error occurs while creating the user.
    """
    try:
        hashed_password = pwd_context.hash(user.password)
        new_user = {
            "username": user.username,
            "email": user.email,
            "hashed_password": hashed_password,
        }
        inserted_user = insert_user(new_user)
        if inserted_user:
            return User(**inserted_user)
        raise HTTPException(status_code=500, detail="Failed to create user.")
    except HTTPException:
        # Let the handler's own HTTP errors through instead of re-wrapping
        # them in the generic message below.
        raise
    except Exception as e:
        logger.error(f"Error creating user: {e}")
        raise HTTPException(status_code=500, detail=f"Error creating user: {e}")
645 |
646 |
647 |
# NOTE(review): the dependency only checks that the caller is authenticated;
# nothing here ties the caller to ``username``. Confirm that is intended.
@app.put("/users/{username}", response_model=User, dependencies=[Depends(get_current_user)])
async def update_user_data(username: str, user: User):
    """Update the email and password of an existing user.

    Args:
        username: Path parameter naming the user to update.
        user: Submitted data; ``email`` and ``password`` are read from it.

    Returns:
        A ``User`` built from the record returned by ``update_user``.

    Raises:
        HTTPException: 404 when ``update_user`` returns a falsy value,
            500 when any other error occurs.
    """
    try:
        hashed_password = pwd_context.hash(user.password)
        updated_user_data = {"email": user.email, "hashed_password": hashed_password}
        updated_user = update_user(username, updated_user_data)
        if updated_user:
            return User(**updated_user)
        raise HTTPException(status_code=404, detail="User not found")
    except HTTPException:
        # Without this clause the 404 above is caught by the generic handler
        # and reported to the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Error updating user: {e}")
        raise HTTPException(status_code=500, detail="Error updating user.")
661 |
662 |
663 |
664 |
# NOTE(review): the dependency only checks that the caller is authenticated;
# nothing here ties the caller to ``username``. Confirm that is intended.
@app.delete("/users/{username}", dependencies=[Depends(get_current_user)])
async def delete_user_account(username: str):
    """Delete the user named in the path.

    Args:
        username: Path parameter naming the user to delete.

    Returns:
        A 200 ``JSONResponse`` with a confirmation message.

    Raises:
        HTTPException: 404 when ``delete_user`` returns a falsy value,
            500 when any other error occurs.
    """
    try:
        deleted_user = delete_user(username)
        if deleted_user:
            return JSONResponse({"message": "User deleted successfully."}, status_code=200)
        raise HTTPException(status_code=404, detail="User not found")
    except HTTPException:
        # Without this clause the 404 above is caught by the generic handler
        # and reported to the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Error deleting user: {e}")
        raise HTTPException(status_code=500, detail="Error deleting user.")
675 |
676 |
677 |
@app.get("/users", dependencies=[Depends(get_current_user)])
async def get_all_users_route():
    """Return whatever ``get_all_users`` yields; requires an authenticated caller."""
    return get_all_users()
680 |
681 |
682 |
683 |
684 |
# NOTE(review): no ``@app.exception_handler(RequestValidationError)`` decorator
# is visible directly above this function; confirm the handler is registered.
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn a request validation failure into a JSON error response.

    Args:
        request: The request that failed validation (unused).
        exc: The validation error; ``exc.errors()`` and ``exc.body`` are
            echoed back to the client.

    Returns:
        A ``JSONResponse`` whose body is an object with ``detail`` and
        ``body`` keys.
    """
    # Imported locally so the block does not rely on a top-level import that
    # may not exist in this file.
    from fastapi.encoders import jsonable_encoder

    # JSONResponse serializes ``content`` itself. Passing a pre-dumped string
    # would send a JSON string literal instead of a JSON object.
    # NOTE(review): the original status_code argument was not visible; 422 is
    # FastAPI's default for validation errors. Confirm.
    return JSONResponse(
        status_code=422,
        content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
    )
688 |
689 |
690 |
691 |
if __name__ == "__main__":
    # Script entry point: prepare the database, then start the ASGI server.
    from pathlib import Path

    import uvicorn

    create_db_and_table()  # Initialize database on startup

    # reload=True needs the app as an import string. Deriving the module name
    # from this file avoids a hard-coded "main" that breaks when the file is
    # named differently.
    # NOTE(review): host="" is kept as written; presumably it binds all
    # interfaces. Confirm.
    uvicorn.run(f"{Path(__file__).stem}:app", host="", port=7860, reload=True)