# Captured from a Hugging Face Spaces page whose status read "Runtime error".
import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline

# Run on the first GPU when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision only on the GPU: many PyTorch CPU kernels have no float16
# implementation, so CPU inference must stay in float32.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32


def _load_classifier(repo_id):
    """Return an image-classification pipeline for the Hub checkpoint *repo_id*.

    The model weights are loaded directly in ``dtype`` so that the model and
    the pipeline agree on precision; the pipeline's own ``torch_dtype``
    argument may not convert a model that is passed in already instantiated.
    """
    model = AutoModelForImageClassification.from_pretrained(repo_id, torch_dtype=dtype)
    return pipeline(
        "image-classification",
        model=model,
        feature_extractor=AutoFeatureExtractor.from_pretrained(repo_id),
        device=device,
        torch_dtype=dtype,
    )


nsfw_pipe = _load_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _load_classifier("cafeai/cafe_style")
aesthetic_pipe = _load_classifier("cafeai/cafe_aesthetic")
def _open_rgb(path):
    """Load the image at *path* into memory as an RGB ``PIL.Image``."""
    # The context manager closes the file handle; ``convert`` returns a new,
    # fully loaded image that remains valid after the file is closed.
    with Image.open(path) as img:
        return img.convert("RGB")


def predict(image, files=None):
    """Classify images with the style, aesthetic and NSFW pipelines.

    Args:
        image: Path of a single image (the ``gr.Image`` input uses
            ``type="filepath"``), or ``None``.
        files: Uploaded files from the ``gr.File`` input, or ``None``. Each
            entry may be a path string or an object exposing its path as
            ``.name``.

    Returns:
        ``(label_data, results)``. ``results`` holds, per classified image,
        the concatenated ``{"label": ..., "score": ...}`` rows of the three
        pipelines; it covers the uploaded files when any are given, otherwise
        *image*. ``label_data`` maps label -> score for *image* itself and is
        empty when no single image was given.
    """
    file_paths = [getattr(f, "name", f) for f in files] if files else []
    # *image* (when present) goes first so its predictions are results[0].
    batch = ([image] if image is not None else []) + file_paths
    if not batch:
        # Nothing to classify: don't call the pipelines or Image.open(None).
        return {}, []

    pil_images = [_open_rgb(path) for path in batch]
    style = style_pipe(pil_images)
    aesthetic = aesthetic_pipe(pil_images)
    nsfw = nsfw_pipe(pil_images)
    # Merge the three models' rows for each image into a single list.
    results = [a + b + c for a, b, c in zip(style, aesthetic, nsfw)]

    label_data = {}
    if image is not None:
        label_data = {row["label"]: row["score"] for row in results[0]}
        if file_paths:
            # The JSON output lists only the uploaded files, as before.
            results = results[1:]
    return label_data, results
# Gradio UI: a single image and/or a batch of image files go in; the label
# scores for the single image and the raw per-image results come out.
with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(
                label="Multiple Images",
                file_types=["image"],
                file_count="multiple",
            )
        with gr.Column():
            label = gr.Label(label="style")
            results = gr.JSON(label="Results")
    btn = gr.Button("Run")
    # Also exposed through the Gradio API under the endpoint name "inference".
    btn.click(
        fn=predict,
        inputs=[image, files],
        outputs=[label, results],
        api_name="inference",
    )

# Enable Gradio's request queue, then start the server.
blocks.queue()
blocks.launch(debug=True, inline=True)