# Hugging Face Space file metadata (page residue, kept as a comment):
# author: radames · commit c9267e5 "fix" · size 2.48 kB
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import torch
from PIL import Image
import gradio as gr
# Run on the first CUDA GPU when one is available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Use half precision only on the GPU: PyTorch CPU kernels have historically
# lacked float16 implementations for many ops, so CPU inference stays float32.
dtype = torch.float16 if device != "cpu" else torch.float32
def _image_classifier(repo_id):
    """Build an image-classification pipeline for the Hub model ``repo_id``.

    The model and its feature extractor are both loaded from ``repo_id``, and
    the pipeline is created on the module-level ``device``.
    """
    # NOTE(review): the model is passed in already instantiated, so depending on
    # the transformers version ``torch_dtype`` may be ignored here instead of
    # casting the weights. Confirm this if half precision is actually intended.
    return pipeline(
        "image-classification",
        model=AutoModelForImageClassification.from_pretrained(repo_id),
        feature_extractor=AutoFeatureExtractor.from_pretrained(repo_id),
        device=device,
        torch_dtype=dtype,
    )


# One classifier per scoring head, loaded in the original order.
nsfw_pipe = _image_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _image_classifier("cafeai/cafe_style")
aesthetic_pipe = _image_classifier("cafeai/cafe_aesthetic")
def _load_rgb(path):
    """Open the image file at ``path``, return an RGB copy, and close the file."""
    with Image.open(path) as img:
        return img.convert("RGB")


def predict(image, files=None):
    """Classify images with the style, aesthetic and NSFW pipelines.

    Args:
        image: Filepath of the single image from the ``gr.Image`` input, or None.
        files: Uploaded files from the ``gr.File`` input, or None/empty. Items
            may be objects exposing ``.name`` (the path) or plain path strings,
            depending on the Gradio version. When any files are given, they are
            classified instead of ``image``.

    Returns:
        A ``(label_data, results)`` pair.
        ``results`` has one entry per classified image: the concatenated
        prediction rows from the style, aesthetic and NSFW pipelines, in that
        order.
        ``label_data`` maps label -> score for the single ``image`` input. It is
        empty when the batch of ``files`` was classified instead, or when no
        input was given.
    """
    if files:
        paths = [getattr(f, "name", f) for f in files]
    elif image is not None:
        paths = [image]
    else:
        # Nothing to classify; Image.open(None) would otherwise raise.
        return {}, []

    pil_images = [_load_rgb(p) for p in paths]
    style = style_pipe(pil_images)
    aesthetic = aesthetic_pipe(pil_images)
    nsfw = nsfw_pipe(pil_images)
    # Each pipeline yields one list of prediction rows per image; join them per image.
    results = [s + a + n for s, a, n in zip(style, aesthetic, nsfw)]

    label_data = {}
    if not files:
        # Only the single ``image`` was classified, so results[0] belongs to it.
        label_data = {row["label"]: row["score"] for row in results[0]}
    return label_data, results
# Gradio UI: a single image or a batch of image files in, scores out.
with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(label="Multiple Images", file_types=["image"], file_count="multiple")
        with gr.Column():
            # Receives the label -> score mapping returned first by predict.
            label = gr.Label(label="style")
            # Receives the per-image result list returned second by predict.
            results = gr.JSON(label="Results")
    btn = gr.Button("Run")
    # api_name also exposes this handler as the "inference" API endpoint.
    btn.click(fn=predict, inputs=[image, files], outputs=[label, results], api_name="inference")

blocks.queue()
blocks.launch(debug=True, inline=True)