Spaces:
Running
Running
# ------------------------------------------------------------------- | |
# This source file is available under the terms of the | |
# Pimcore Open Core License (POCL) | |
# Full copyright and license information is available in | |
# LICENSE.md which is distributed with this source code. | |
# | |
# @copyright Copyright (c) Pimcore GmbH (https://www.pimcore.com) | |
# @license Pimcore Open Core License (POCL) | |
# ------------------------------------------------------------------- | |
import torch | |
from fastapi import FastAPI, Path, Request | |
import logging | |
import sys | |
from .translation_task import TranslationTaskService | |
from .classification import ClassificationTaskService | |
from .text_to_image import TextToImageTaskService | |
# FastAPI application instance for the local inference service; title,
# description and version are what the generated OpenAPI docs display.
app = FastAPI(
    title="Pimcore Local Inference Service",
    description=(
        "This service allows HF inference provider compatible inference "
        "for models which are not available at HF inference providers."
    ),
    version="1.0.0",
)
# Module-level logger handed to every task service created in this file.
# It is set to DEBUG so that records of every level pass this logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Install a root handler printing "<timestamp> <LEVEL padded to 8> <message>".
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
)
class StreamToLogger:
    """Minimal file-like object that forwards written text to a logger.

    Instances can stand in for ``sys.stdout`` / ``sys.stderr`` so that
    ``print`` output and other stream writes become log records.

    Args:
        logger: The ``logging.Logger`` that receives each written line.
        log_level: Level (e.g. ``logging.INFO``) used for every record.
    """

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        # Unused: writes are forwarded immediately, not buffered. Kept so
        # existing attribute access keeps working.
        self.linebuf = ''

    def write(self, buf):
        """Log every line of ``buf`` and return the number of characters written.

        Trailing whitespace is stripped first, so the newline ``print``
        appends does not produce an extra empty record. Returning
        ``len(buf)`` honours the ``TextIOBase.write`` contract that some
        callers check.
        """
        for line in buf.rstrip().splitlines():
            self.logger.log(self.log_level, line.rstrip())
        return len(buf)

    def flush(self):
        """No-op: ``write`` hands every line to the logger straight away."""

    def isatty(self):
        """Report that this stream is not an interactive terminal.

        Libraries commonly call ``sys.stdout.isatty()`` / ``sys.stderr.isatty()``
        (e.g. to decide on colours or progress bars); without this method
        they would raise ``AttributeError`` once the streams are replaced.
        """
        return False
# Send everything written to the process streams through `logger`:
# stdout at INFO, stderr (e.g. tracebacks) at ERROR.
# NOTE(review): a StreamHandler created after this point without an explicit
# stream defaults to the *new* sys.stderr; attaching one to `logger` or the
# root logger could feed records back into this logger — confirm none exists.
sys.stdout, sys.stderr = (
    StreamToLogger(logger, logging.INFO),
    StreamToLogger(logger, logging.ERROR),
)
async def gpu_check():
    """Report whether PyTorch can see a CUDA device.

    The verdict is also printed to ``sys.stdout``. Note that the printed
    wording ("GPU is not available") differs slightly from the returned one
    ("GPU not available").

    Returns:
        dict: ``{'success': True, 'gpu': 'GPU is available'}`` or
        ``{'success': True, 'gpu': 'GPU not available'}``.
    """
    if not torch.cuda.is_available():
        print("GPU is not available")
        return {'success': True, 'gpu': 'GPU not available'}

    print("GPU is available")
    return {'success': True, 'gpu': 'GPU is available'}
# ========================= | |
# Translation Task | |
# ========================= | |
async def translate(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the translation model (e.g. Helsinki-NLP/opus-mt-en-de)",
        example="Helsinki-NLP/opus-mt-en-de",
    ),
):
    """
    Run a translation with the model named in the URL path.

    A new ``TranslationTaskService`` is built for every call and the request
    is passed to it untouched; all payload handling happens in the service.

    Args:
        request: Incoming HTTP request (presumably carrying the text to
            translate — parsed by the service, not here).
        model_name: Model identifier, e.g. ``Helsinki-NLP/opus-mt-en-de``.

    Returns:
        Whatever ``TranslationTaskService.translate`` returns (described
        previously as the pipeline's translation result list).
    """
    service = TranslationTaskService(logger)
    translation = await service.translate(request, model_name)
    return translation
# ========================= | |
# Zero-Shot Image Classification Task | |
# ========================= | |
async def zero_shot_image_classification(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the zero-shot classification model (e.g., openai/clip-vit-large-patch14-336)",
        example="openai/clip-vit-large-patch14-336",
    ),
):
    """
    Classify an image against caller-supplied labels with a zero-shot model.

    Delegates to a ``ClassificationTaskService`` configured for the
    ``'zero-shot-image-classification'`` task; the request is forwarded as is.

    Args:
        request: Incoming HTTP request, interpreted by the service.
        model_name: Model identifier, e.g. ``openai/clip-vit-large-patch14-336``.

    Returns:
        The value produced by ``ClassificationTaskService.classify``
        (described previously as the pipeline's classification result list).
    """
    classifier = ClassificationTaskService(logger, 'zero-shot-image-classification')
    outcome = await classifier.classify(request, model_name)
    return outcome
# ========================= | |
# Image Classification Task | |
# ========================= | |
async def image_classification(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the image classification model (e.g., pimcore/car-countries-classification)",
        example="pimcore/car-countries-classification",
    ),
):
    """
    Classify an image with the model named in the URL path.

    Delegates to a ``ClassificationTaskService`` configured for the
    ``'image-classification'`` task; the request is forwarded as is.

    Args:
        request: Incoming HTTP request, interpreted by the service.
        model_name: Model identifier, e.g. ``pimcore/car-countries-classification``.

    Returns:
        The value produced by ``ClassificationTaskService.classify``
        (described previously as the pipeline's classification result list).
    """
    classifier = ClassificationTaskService(logger, 'image-classification')
    outcome = await classifier.classify(request, model_name)
    return outcome
# ========================= | |
# Zero-Shot Text Classification Task | |
# ========================= | |
async def zero_shot_text_classification(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the zero-shot text classification model (e.g., facebook/bart-large-mnli)",
        example="facebook/bart-large-mnli",
    ),
):
    """
    Classify text against caller-supplied labels with a zero-shot model.

    Delegates to a ``ClassificationTaskService`` configured for the
    ``'zero-shot-classification'`` task; the request is forwarded as is.

    Args:
        request: Incoming HTTP request, interpreted by the service.
        model_name: Model identifier, e.g. ``facebook/bart-large-mnli``.

    Returns:
        The value produced by ``ClassificationTaskService.classify``
        (described previously as the pipeline's classification result list).
    """
    classifier = ClassificationTaskService(logger, 'zero-shot-classification')
    outcome = await classifier.classify(request, model_name)
    return outcome
# ========================= | |
# Text Classification Task | |
# ========================= | |
async def text_classification(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the text classification model (e.g., pimcore/car-class-classification)",
        example="pimcore/car-class-classification",
    ),
):
    """
    Classify text with the model named in the URL path.

    Delegates to a ``ClassificationTaskService`` configured for the
    ``'text-classification'`` task; the request is forwarded as is.

    Args:
        request: Incoming HTTP request, interpreted by the service.
        model_name: Model identifier, e.g. ``pimcore/car-class-classification``.

    Returns:
        The value produced by ``ClassificationTaskService.classify``
        (described previously as the pipeline's classification result list).
    """
    classifier = ClassificationTaskService(logger, 'text-classification')
    outcome = await classifier.classify(request, model_name)
    return outcome
# ========================= | |
# Image to Text Task | |
# ========================= | |
async def image_to_text(
    request: Request,
    model_name: str = Path(
        ...,
        description="The name of the image-to-text model (e.g., Salesforce/blip-image-captioning-base)",
        example="Salesforce/blip-image-captioning-base",
    ),
):
    """
    Generate text from an image with the model named in the URL path.

    A new task service is built for every call and the request is passed to
    its ``extract`` method untouched.

    Args:
        request: Incoming HTTP request, interpreted by the service.
        model_name: Model identifier, e.g. ``Salesforce/blip-image-captioning-base``.

    Returns:
        The value produced by the service's ``extract`` method (described
        previously as the generated text returned by the pipeline).
    """
    # NOTE(review): the service is called TextToImageTaskService although this
    # endpoint does image-to-text; confirm `.extract` is the intended entry point.
    image_to_text_service = TextToImageTaskService(logger)
    return await image_to_text_service.extract(request, model_name)