|
from asyncio import gather, to_thread
from base64 import b64decode
from binascii import Error as BinasciiError
from contextlib import asynccontextmanager
from io import BytesIO
from logging import Formatter, INFO, StreamHandler, getLogger
from pathlib import Path
from random import randint
from typing import AsyncGenerator
from uuid import UUID

from PIL.Image import open as image_open
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from httpx import AsyncClient
from starlette.responses import Response
|
|
|
# Application logger: INFO level, console output formatted as
# "dd.mm.YYYY HH:MM:SS | LEVEL : message".
logger = getLogger('NVIDIA_VLM_API')
logger.setLevel(INFO)

formatter = Formatter('%(asctime)s | %(levelname)s : %(message)s', datefmt='%d.%m.%Y %H:%M:%S')

# `getLogger` returns the same logger object every time this module body runs
# (e.g. after `importlib.reload`). Attach the console handler only once;
# otherwise each re-run stacks another StreamHandler and every record is
# printed several times.
if not logger.handlers:
    handler = StreamHandler()
    handler.setLevel(INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

logger.info('инициализация приложения...')
|
|
|
# Invoke endpoints of the vision-language models that are queried concurrently
# for every request; the path after `/v1/vlm/` doubles as the model name.
INVOKE_URLS = [
    f'https://ai.api.nvidia.com/v1/vlm/{model}'
    for model in (
        'microsoft/phi-3-vision-128k-instruct',
        'nvidia/neva-22b',
        'nvidia/vila',
    )
]

# NVCF asset endpoint: images are POSTed here to obtain an upload slot and
# later removed with DELETE `{ASSETS_URL}/{asset_id}`.
ASSETS_URL = 'https://api.nvcf.nvidia.com/v2/nvcf/assets'
|
|
|
|
|
def get_extension(filename: str) -> str:
    """Return the lower-cased extension of *filename* without its leading dot.

    Only the last suffix counts (``'a.tar.GZ'`` -> ``'gz'``); names without a
    suffix, including dot-files such as ``'.env'``, give ``''``.
    """
    suffix = Path(filename).suffix
    return suffix.removeprefix('.').lower()
|
|
|
|
|
async def upload_asset(client: AsyncClient, media_file_bytes: bytes, description: str, api_key: str) -> UUID:
    """Upload a JPEG image as an NVCF asset and return its asset id.

    Two requests are made:
      1. ``POST ASSETS_URL`` (30 s timeout) registers the asset and returns JSON
         containing ``uploadUrl`` and ``assetId``.
      2. ``PUT uploadUrl`` (300 s timeout) sends the raw bytes. The
         ``x-amz-meta-*`` header suggests an S3-style URL (unverified).

    Args:
        client: Open HTTP client used for both requests.
        media_file_bytes: Image bytes, sent with content type ``image/jpeg``.
        description: Asset description, sent in both requests.
        api_key: NVAPI key used as a bearer token for the registration call.

    Returns:
        The ``assetId`` from the registration response, parsed as a UUID.

    Raises:
        The client's HTTP status error (via ``raise_for_status``) if either
        request returns a non-success status.
    """
    # Step 1: register the asset and obtain the URL to upload to.
    registration = await client.post(
        ASSETS_URL,
        json={'contentType': 'image/jpeg', 'description': description},
        headers={
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json',
            'accept': 'application/json',
        },
        timeout=30,
    )
    registration.raise_for_status()
    slot = registration.json()

    # Step 2: push the image bytes to the returned upload URL.
    upload_headers = {'x-amz-meta-nvcf-asset-description': description, 'content-type': 'image/jpeg'}
    upload = await client.put(slot.get('uploadUrl'), content=media_file_bytes, headers=upload_headers, timeout=300)
    upload.raise_for_status()

    return UUID(slot.get('assetId'))
|
|
|
|
|
async def delete_asset(client: AsyncClient, asset_id: UUID, api_key: str) -> None:
    """Delete an NVCF asset by id.

    Sends ``DELETE {ASSETS_URL}/{asset_id}`` with the key as a bearer token and
    a 30-second timeout. A non-success status raises via ``raise_for_status``.
    """
    asset_url = f'{ASSETS_URL}/{asset_id}'
    auth = {'Authorization': f'Bearer {api_key}'}
    result = await client.delete(asset_url, headers=auth, timeout=30)
    result.raise_for_status()
|
|
|
|
|
async def chat_with_media_nvcf(infer_url: str, media_file_bytes: bytes, query: str, api_key: str) -> str | None:
    """Ask one NVIDIA VLM endpoint a question about an image.

    The image is uploaded as an NVCF asset and referenced from the user message
    via ``<img src="data:image/jpeg;asset_id,..." />``. The asset is deleted
    again afterwards, including when the inference request fails.

    Args:
        infer_url: Model invoke URL; the path after ``/v1/vlm/`` is sent as ``model``.
        media_file_bytes: JPEG image bytes to upload.
        query: Prompt text, placed before the image reference.
        api_key: NVAPI key sent as a bearer token.

    Returns:
        ``choices[0].message.content`` from the model response, or ``None`` if
        it is absent or any step fails. Failures are logged, not raised.
    """
    try:
        async with AsyncClient(follow_redirects=True, timeout=45) as client:
            asset_id = await upload_asset(client, media_file_bytes, 'Reference media file', api_key)
            try:
                asset_ref = str(asset_id)
                headers = {
                    'Authorization': f'Bearer {api_key}',
                    'Content-Type': 'application/json',
                    'NVCF-INPUT-ASSET-REFERENCES': asset_ref,
                    'NVCF-FUNCTION-ASSET-IDS': asset_ref,
                    'Accept': 'application/json',
                }
                media_content = f'<img src="data:image/jpeg;asset_id,{asset_id}" />'
                payload = {
                    'max_tokens': 1024,
                    'temperature': 0.65,
                    'top_p': 0.95,
                    'seed': randint(0, 999999999),
                    'messages': [{'role': 'user', 'content': f'{query} {media_content}'}],
                    'stream': False,
                    'model': infer_url.split('/v1/vlm/')[-1],
                }

                response = await client.post(infer_url, headers=headers, json=payload)
                # Treat an HTTP error as a failure instead of parsing the
                # error body as if it were a completion.
                response.raise_for_status()
                # `or [{}]` also covers a missing, null or empty `choices` list.
                choices = response.json().get('choices') or [{}]
                return choices[0].get('message', {}).get('content')
            finally:
                # Always remove the uploaded asset. A failed cleanup is logged
                # but does not discard an answer that was already obtained.
                try:
                    await delete_asset(client, asset_id, api_key)
                except Exception:
                    logger.warning('не удалось удалить asset %s (%s)', asset_id, infer_url, exc_info=True)
    except Exception:
        logger.exception('ошибка запроса к модели %s', infer_url)
        return None
|
|
|
|
|
def base64_to_jpeg_bytes(base64_str: str) -> bytes:
    """Decode a ``<header>,<base64 payload>`` string and re-encode it as JPEG.

    Everything up to the first comma (e.g. ``data:image/png;base64``) is
    discarded. The remainder is base64-decoded, opened with Pillow, converted
    to RGB and saved as JPEG (quality 90, ``optimize=True``).

    Raises:
        ValueError: if the string contains no comma, or if decoding/opening the
            payload raises ``binascii.Error`` or ``OSError`` (the original
            error is chained as the cause).
    """
    _header, separator, payload = base64_str.partition(',')
    if not separator:
        raise ValueError('недопустимый формат строки base64')

    try:
        raw = b64decode(payload)
        with image_open(BytesIO(raw)) as picture, BytesIO() as out:
            picture.convert('RGB').save(out, format='JPEG', quality=90, optimize=True)
            return out.getvalue()
    except (BinasciiError, OSError) as err:
        raise ValueError('данные не являются корректным изображением') from err
|
|
|
|
|
async def get_captions(image_base64_str: str, query: str, api_key: str) -> dict[str, str | None]:
    """Ask every model in ``INVOKE_URLS`` about one image, concurrently.

    Args:
        image_base64_str: Image as a ``<header>,<base64 payload>`` string.
        query: Prompt text sent to each model.
        api_key: NVAPI key.

    Returns:
        Mapping of model name (URL path after ``/v1/vlm/``) to its answer, or
        ``None`` for a model that produced no answer.

    Raises:
        ValueError: propagated from ``base64_to_jpeg_bytes`` when the image
            string cannot be decoded.
    """
    # Pillow decoding/re-encoding is synchronous CPU work. Running it in a
    # worker thread keeps the event loop free to serve other requests meanwhile.
    media_file_bytes = await to_thread(base64_to_jpeg_bytes, image_base64_str)
    answers = await gather(
        *(chat_with_media_nvcf(url, media_file_bytes, query, api_key) for url in INVOKE_URLS)
    )
    return {url.split('/v1/vlm/')[-1]: answer for url, answer in zip(INVOKE_URLS, answers)}
|
|
|
|
|
@asynccontextmanager
async def app_lifespan(_) -> AsyncGenerator[None, None]:
    """FastAPI lifespan hook that only logs start-up and shutdown.

    The application argument is unused. The shutdown message is logged in
    ``finally``, so it is emitted however the lifespan ends.
    """
    logger.info('запуск приложения')
    try:
        logger.info('старт API')
        # Control returns to the server here for the lifetime of the app.
        yield
    finally:
        logger.info('приложение завершено')
|
|
|
|
|
# ASGI application; `app_lifespan` logs start-up and shutdown.
app = FastAPI(
    title='NVIDIA_VLM_API',
    lifespan=app_lifespan,
)
|
|
|
# Request paths of FastAPI's auto-generated schema and documentation routes;
# the `block_banned_endpoints` middleware answers them with 403. Entries are
# compared verbatim with `request.url.path`, so each must be a full path
# starting with '/'. The Swagger OAuth2 redirect route is served at
# '/docs/oauth2-redirect'.
banned_endpoints = [
    '/openapi.json',
    '/docs',
    '/docs/oauth2-redirect',
    '/redoc',
]
|
|
|
|
|
@app.middleware('http')
async def block_banned_endpoints(request: Request, call_next):
    """HTTP middleware: answer 403 for paths in ``banned_endpoints``.

    Every other request is passed on unchanged to the next handler.
    """
    path = request.url.path
    logger.debug(f'получен запрос: {path}')
    if path not in banned_endpoints:
        return await call_next(request)
    logger.warning(f'запрещенный endpoint: {path}')
    return Response(status_code=403)
|
|
|
|
|
def _extract_api_key(authorization: str | None) -> str:
    """Return the NVAPI key carried by an ``Authorization`` header value.

    ``Bearer <key>`` is accepted with the scheme matched case-insensitively.
    Any other non-empty value is used as the raw key. Returns ``''`` when the
    header is missing or carries no key.
    """
    if not authorization:
        return ''
    value = authorization.strip()
    scheme, _, token = value.partition(' ')
    if scheme.lower() == 'bearer':
        return token.strip()
    return value


def _collect_text_and_image(messages: object) -> tuple[str, str]:
    """Flatten chat-style messages into ``(prompt_text, image_data_url)``.

    Text from ``system`` and ``user`` messages (plain-string content or
    ``{'type': 'text'}`` parts) is joined with single spaces. Only
    ``data:image/...`` URLs from ``{'type': 'image_url'}`` parts are taken, and
    the last one wins. Malformed entries (non-dict messages or parts,
    non-string fields) are skipped instead of raising.
    """
    if not isinstance(messages, list):
        return '', ''
    texts: list[str] = []
    image_data = ''
    for message in messages:
        if not isinstance(message, dict) or message.get('role') not in ('system', 'user'):
            continue
        content = message.get('content')
        if isinstance(content, str):
            texts.append(content)
            continue
        if not isinstance(content, list):
            continue
        for item in content:
            if not isinstance(item, dict):
                continue
            if item.get('type') == 'text':
                text = item.get('text', '')
                if isinstance(text, str):
                    texts.append(text)
            elif item.get('type') == 'image_url':
                image_url = item.get('image_url')
                url = image_url.get('url') if isinstance(image_url, dict) else None
                if isinstance(url, str) and url.startswith('data:image/'):
                    image_data = url
    return ' '.join(texts).strip(), image_data.strip()


@app.post('/v1/describe')
async def describe_v1(request: Request):
    """Describe a base64 image with every configured VLM.

    Responses:
        200 with ``{model_name: answer}``;
        401 if no NVAPI key is supplied in ``Authorization``;
        400 if the body lacks prompt text or a ``data:image/`` URL (including
        unparseable JSON), or if the image cannot be decoded;
        500 for any other failure.
    """
    logger.info('запрос `describe_v1`')

    # Starlette headers are case-insensitive, so one lookup covers every spelling.
    # A missing header yields '' here instead of crashing on `None`.
    nvapi_key = _extract_api_key(request.headers.get('Authorization'))
    if not nvapi_key:
        return JSONResponse({'caption': 'в запросе нужно передать заголовок `Authorization: Bearer <NVAPI_KEY>`'}, status_code=401)

    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError are ValueErrors
        body = None
    messages = body.get('messages', []) if isinstance(body, dict) else []
    content_text, image_data = _collect_text_and_image(messages)

    if not content_text or not image_data:
        return JSONResponse({'caption': 'изображение должно быть передано как строка base64 `data:image/jpeg;base64,{base64_img}` а также текст'}, status_code=400)

    try:
        captions = await get_captions(image_data, content_text, nvapi_key)
    except ValueError as e:
        # Raised when the supplied image cannot be decoded: a client error.
        return JSONResponse({'caption': str(e)}, status_code=400)
    except Exception as e:
        logger.exception('ошибка обработки запроса `describe_v1`')
        return JSONResponse({'caption': str(e)}, status_code=500)
    return JSONResponse(captions, status_code=200)
|
|
|
|
|
@app.get('/')
async def root():
    """Serve a fixed placeholder HTML page at the site root (status 200)."""
    page = 'ну пролапс, ну и что'
    return HTMLResponse(content=page, status_code=200)
|
|
|
|
|
if __name__ == '__main__':
    # uvicorn is imported only when this file runs as a script, so importing
    # the module elsewhere does not pull it in.
    import uvicorn

    logger.info('запуск сервера uvicorn')
    uvicorn.run(app, host='0.0.0.0', port=7860)
|
|