File size: 2,480 Bytes
6974603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caeb1f4
 
4fecf45
caeb1f4
 
 
 
 
 
 
c9267e5
 
 
 
6974603
 
caeb1f4
 
 
 
 
 
 
 
 
 
 
6974603
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import torch 
from PIL import Image
import gradio as gr

# Run on the first GPU when available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Use half precision only on CUDA: many CPU kernels have no float16
# implementation, so requesting fp16 on CPU can fail at inference time.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32


def _load_classifier(repo_id):
    """Build an image-classification pipeline for the Hub checkpoint ``repo_id``.

    The model and its feature extractor are both loaded from ``repo_id``.
    The pipeline is placed on the module-level ``device`` and given ``dtype``.
    """
    # NOTE(review): the model is passed already instantiated, so whether
    # `torch_dtype` is actually applied to it depends on the installed
    # transformers version — confirm.
    return pipeline(
        "image-classification",
        model=AutoModelForImageClassification.from_pretrained(repo_id),
        feature_extractor=AutoFeatureExtractor.from_pretrained(repo_id),
        device=device,
        torch_dtype=dtype,
    )


# Built at import time, in this order (from_pretrained may download weights).
nsfw_pipe = _load_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _load_classifier("cafeai/cafe_style")
aesthetic_pipe = _load_classifier("cafeai/cafe_aesthetic")

def predict(image, files=None):
    """Classify one image, or a batch of uploaded images, with all three models.

    Args:
        image: Filepath of the single image from the ``gr.Image`` input
            (declared with ``type="filepath"``), or ``None`` when nothing was
            supplied.
        files: Uploaded files from the ``gr.File`` input, or ``None``. Each
            entry is either an object exposing its path as ``.name`` or a plain
            path string. When non-empty, these replace ``image`` as the batch
            to classify.

    Returns:
        A ``(label_data, results)`` pair.
        - ``results`` holds one list per classified image. Each list
          concatenates the style, aesthetic and NSFW pipeline rows, in that
          order.
        - ``label_data`` maps label -> score for ``image`` when ``image`` was
          the classified input, and is ``{}`` otherwise.
    """
    if files:
        # The original code read `.name` from each upload; fall back to the
        # value itself so plain path strings are accepted too.
        image_paths = [getattr(f, "name", f) for f in files]
    elif image is not None:
        image_paths = [image]
    else:
        # Nothing to classify: avoid Image.open(None) failing the request.
        return {}, []

    pil_images = []
    for path in image_paths:
        # Close the file handle; convert() returns a new, fully loaded image
        # that remains usable after the source file is closed.
        with Image.open(path) as img:
            pil_images.append(img.convert("RGB"))

    style = style_pipe(pil_images)
    aesthetic = aesthetic_pipe(pil_images)
    nsfw = nsfw_pipe(pil_images)
    results = [a + b + c for a, b, c in zip(style, aesthetic, nsfw)]

    label_data = {}
    if not files:
        # results[0] belongs to `image` only when the files batch was not used.
        label_data = {row["label"]: row["score"] for row in results[0]}

    return label_data, results

# Two-column UI: inputs (single image or multiple uploads) on the left,
# the single-image label view and full JSON results on the right.
with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(label="Multiple Images", file_types=["image"], file_count="multiple")
        with gr.Column():
            label = gr.Label(label="style")
            results = gr.JSON(label="Results")
    btn = gr.Button("Run")

    # The same handler also backs the named API endpoint "inference".
    btn.click(fn=predict, inputs=[image, files], outputs=[label, results], api_name="inference")

# Enable Gradio's request queue, then start the server.
blocks.queue()
blocks.launch(debug=True, inline=True)