# -------------------------------------------------------------------
# This source file is available under the terms of the
# Pimcore Open Core License (POCL)
# Full copyright and license information is available in
# LICENSE.md which is distributed with this source code.
#
# @copyright Copyright (c) Pimcore GmbH (https://www.pimcore.com)
# @license Pimcore Open Core License (POCL)
# -------------------------------------------------------------------
import logging
import sys

import torch
from fastapi import FastAPI, Path, Request

from .translation_task import TranslationTaskService
from .classification import ClassificationTaskService
from .text_to_image import TextToImageTaskService

app = FastAPI(
    title="Pimcore Local Inference Service",
    description="This service provides HF-inference-provider-compatible inference for models that are not available from HF inference providers.",
    version="1.0.0"
)

logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

class StreamToLogger(object):
    """File-like stream that redirects writes (e.g. print output) to a logger."""

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def write(self, buf):
        for line in buf.rstrip().splitlines():
            self.logger.log(self.log_level, line.rstrip())

    def flush(self):
        pass


sys.stdout = StreamToLogger(logger, logging.INFO)
sys.stderr = StreamToLogger(logger, logging.ERROR)
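
# Running the service (sketch): typically started with uvicorn, e.g.
#   uvicorn main:app --host 0.0.0.0 --port 8000
# The module path ("main:app") is an assumption and depends on the package layout;
# the request examples below assume the service is reachable at http://localhost:8000.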

@app.get("/gpu_check")
async def gpu_check():
    """ Check if a GPU is available """
    gpu = 'GPU not available'
    if torch.cuda.is_available():
        gpu = 'GPU is available'
        print("GPU is available")
    else:
        print("GPU is not available")
    return {'success': True, 'gpu': gpu}
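
# Example request (sketch, assuming the service listens on localhost:8000):
#   curl http://localhost:8000/gpu_check
# returns {"success": true, "gpu": "GPU is available"} or {"success": true, "gpu": "GPU not available"}.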
# =========================
# Translation Task
# =========================
@app.post(
    "/translation/{model_name:path}/",
    openapi_extra={
        "requestBody": {
            "content": {
                "application/json": {
                    "example": {
                        "inputs": "Hello, world! foo bar",
                        "parameters": {"repetition_penalty": 1.6}
                    }
                }
            }
        }
    }
)
async def translate(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the translation model (e.g. Helsinki-NLP/opus-mt-en-de)",
        example="Helsinki-NLP/opus-mt-en-de"
    )
):
    """
    Execute translation tasks.

    Returns:
        list: The translation result(s) as returned by the pipeline.
    """
    translationTaskService = TranslationTaskService(logger)
    return await translationTaskService.translate(request, model_name)
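
# Example request (sketch, assuming the service listens on localhost:8000):
#   curl -X POST http://localhost:8000/translation/Helsinki-NLP/opus-mt-en-de/ \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Hello, world! foo bar", "parameters": {"repetition_penalty": 1.6}}'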
# =========================
# Zero-Shot Image Classification Task
# =========================
@app.post(
    "/zero-shot-image-classification/{model_name:path}/",
    openapi_extra={
        "requestBody": {
            "content": {
                "application/json": {
                    "example": {
                        "inputs": "base64_encoded_image_string",
                        "parameters": {"candidate_labels": "green, yellow, blue, white, silver"}
                    }
                }
            }
        }
    }
)
async def zero_shot_image_classification(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the zero-shot classification model (e.g., openai/clip-vit-large-patch14-336)",
        example="openai/clip-vit-large-patch14-336"
    )
):
    """
    Execute zero-shot image classification tasks.

    Returns:
        list: The classification result(s) as returned by the pipeline.
    """
    zeroShotTask = ClassificationTaskService(logger, 'zero-shot-image-classification')
    return await zeroShotTask.classify(request, model_name)
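
# Example request (sketch, assuming the service listens on localhost:8000 and that
# "inputs" takes the raw base64 string, as in the OpenAPI example above; "car.jpg"
# is a placeholder file name):
#   import base64, requests
#   with open("car.jpg", "rb") as f:
#       encoded = base64.b64encode(f.read()).decode("utf-8")
#   requests.post(
#       "http://localhost:8000/zero-shot-image-classification/openai/clip-vit-large-patch14-336/",
#       json={"inputs": encoded, "parameters": {"candidate_labels": "green, yellow, blue, white, silver"}},
#   )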
# =========================
# Image Classification Task
# =========================
@app.post(
    "/image-classification/{model_name:path}/",
    openapi_extra={
        "requestBody": {
            "content": {
                "application/json": {
                    "example": {
                        "inputs": "base64_encoded_image_string"
                    }
                }
            }
        }
    }
)
async def image_classification(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the image classification model (e.g., pimcore/car-countries-classification)",
        example="pimcore/car-countries-classification"
    )
):
    """
    Execute image classification tasks.

    Returns:
        list: The classification result(s) as returned by the pipeline.
    """
    imageTask = ClassificationTaskService(logger, 'image-classification')
    return await imageTask.classify(request, model_name)
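
# Example request (sketch, assuming the service listens on localhost:8000; the body
# carries a base64-encoded image string, as in the OpenAPI example above):
#   curl -X POST http://localhost:8000/image-classification/pimcore/car-countries-classification/ \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "<base64_encoded_image_string>"}'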
# =========================
# Zero-Shot Text Classification Task
# =========================
@app.post(
    "/zero-shot-text-classification/{model_name:path}/",
    openapi_extra={
        "requestBody": {
            "content": {
                "application/json": {
                    "example": {
                        "inputs": "text to classify",
                        "parameters": {"candidate_labels": "green, yellow, blue, white, silver"}
                    }
                }
            }
        }
    }
)
async def zero_shot_text_classification(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the zero-shot text classification model (e.g., facebook/bart-large-mnli)",
        example="facebook/bart-large-mnli"
    )
):
    """
    Execute zero-shot text classification tasks.

    Returns:
        list: The classification result(s) as returned by the pipeline.
    """
    zeroShotTask = ClassificationTaskService(logger, 'zero-shot-classification')
    return await zeroShotTask.classify(request, model_name)
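
# Example request (sketch, assuming the service listens on localhost:8000):
#   curl -X POST http://localhost:8000/zero-shot-text-classification/facebook/bart-large-mnli/ \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "text to classify", "parameters": {"candidate_labels": "green, yellow, blue, white, silver"}}'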
# =========================
# Text Classification Task
# =========================
@app.post(
    "/text-classification/{model_name:path}/",
    openapi_extra={
        "requestBody": {
            "content": {
                "application/json": {
                    "example": {
                        "inputs": "text to classify"
                    }
                }
            }
        }
    }
)
async def text_classification(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the text classification model (e.g., pimcore/car-class-classification)",
        example="pimcore/car-class-classification"
    )
):
    """
    Execute text classification tasks.

    Returns:
        list: The classification result(s) as returned by the pipeline.
    """
    textTask = ClassificationTaskService(logger, 'text-classification')
    return await textTask.classify(request, model_name)
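
# Example request (sketch, assuming the service listens on localhost:8000):
#   curl -X POST http://localhost:8000/text-classification/pimcore/car-class-classification/ \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "text to classify"}'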
# =========================
# Image to Text Task
# =========================
@app.post(
    "/image-to-text/{model_name:path}/",
    openapi_extra={
        "requestBody": {
            "content": {
                "multipart/form-data": {
                    "schema": {
                        "type": "object",
                        "properties": {
                            "image": {
                                "type": "string",
                                "format": "binary",
                                "description": "Image file to upload"
                            }
                        },
                        "required": ["image"]
                    }
                }
            }
        }
    }
)
async def image_to_text(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the image-to-text model (e.g., Salesforce/blip-image-captioning-base)",
        example="Salesforce/blip-image-captioning-base"
    )
):
    """
    Execute image-to-text tasks.

    Returns:
        list: The generated text as returned by the pipeline.
    """
    imageToTextTask = TextToImageTaskService(logger)
    return await imageToTextTask.extract(request, model_name)
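
# Example request (sketch, assuming the service listens on localhost:8000; the endpoint
# expects a multipart "image" file field, per the OpenAPI schema above; "photo.jpg" is a
# placeholder file name):
#   curl -X POST http://localhost:8000/image-to-text/Salesforce/blip-image-captioning-base/ \
#        -F "image=@photo.jpg"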