Spaces:
Runtime error
Runtime error
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline | |
import torch | |
from PIL import Image | |
import gradio as gr | |
import aiohttp | |
import asyncio | |
from io import BytesIO | |
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is only requested on GPU. On CPU, float16 kernels may be
# missing or very slow depending on the PyTorch build, so use float32 there.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32


def _load_classifier(model_id):
    """Build an image-classification pipeline for the checkpoint ``model_id``.

    The model and its feature extractor are both loaded from the same
    checkpoint name and placed on the module-level ``device`` with ``dtype``.
    """
    return pipeline(
        "image-classification",
        model=AutoModelForImageClassification.from_pretrained(model_id),
        feature_extractor=AutoFeatureExtractor.from_pretrained(model_id),
        device=device,
        torch_dtype=dtype,
    )


# One classifier per task; all three are applied to every image in predict().
nsfw_pipe = _load_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _load_classifier("cafeai/cafe_style")
aesthetic_pipe = _load_classifier("cafeai/cafe_aesthetic")
async def fetch_image(session, image_url):
    """Download ``image_url`` and decode it into an RGB PIL image.

    Args:
        session: Client session whose ``get(url)`` returns an async context
            manager yielding a response with ``status``, ``headers`` and an
            awaitable ``read()``.
        image_url: URL of the image to download.

    Returns:
        The decoded image converted to RGB, or ``None`` when the response
        status is not 200 or its content type does not start with ``image``.
    """
    print(f"fetching image {image_url}")
    async with session.get(image_url) as response:
        # A response may omit Content-Type altogether; treat that as
        # "not an image" instead of failing with KeyError.
        content_type = response.headers.get("content-type", "")
        if response.status != 200 or not content_type.startswith("image"):
            return None
        payload = await response.read()
    return Image.open(BytesIO(payload)).convert("RGB")
async def fetch_images(image_urls):
    """Download every URL in ``image_urls`` concurrently.

    A single client session is shared by all downloads. The returned list is
    in the same order as ``image_urls``; each entry is whatever
    ``fetch_image`` produced for that URL.
    """
    async with aiohttp.ClientSession() as session:
        pending = []
        for url in image_urls:
            # Schedule each download right away so they overlap.
            pending.append(asyncio.ensure_future(fetch_image(session, url)))
        fetched = await asyncio.gather(*pending)
    return fetched
async def predict(json=None, enable_gallery=True, image=None, files=None):
    """Classify images by style, aesthetic and NSFW content.

    Exactly one input source is used, in this order of precedence: the single
    ``image`` path, then the uploaded ``files``, then the URLs in ``json``.

    Args:
        json: Mapping with a ``"urls"`` list of image URLs to download.
        enable_gallery: When true, the loaded images are returned as the
            third element of the result so they can be displayed.
        image: Filesystem path of a single image.
        files: Uploaded file objects whose ``.name`` is a filesystem path.

    Returns:
        A tuple ``(results, label_data, gallery_images)``:

        * ``results`` has one entry per requested image, in input order. Each
          entry is the concatenated output of the style, aesthetic and NSFW
          classifiers, or ``None`` when that image could not be fetched.
        * ``label_data`` maps label to score for the single ``image`` input
          and is empty otherwise.
        * ``gallery_images`` is the list of successfully loaded images, or
          ``None`` when ``enable_gallery`` is false.

        When no input is supplied the result is ``([], {}, [])`` (or ``None``
        for the gallery when it is disabled).
    """
    print(json)
    if image:
        image_paths = [image]
    elif files:
        image_paths = [uploaded.name for uploaded in files]
    else:
        image_paths = None

    if image_paths is not None:
        pil_images = [Image.open(path).convert("RGB") for path in image_paths]
    else:
        # Tolerate a missing payload or a payload without "urls" instead of
        # raising; an empty URL list simply yields no images.
        urls = json.get("urls") if isinstance(json, dict) else None
        pil_images = list(await fetch_images(urls)) if urls else []

    # Failed downloads come back as None; the classifiers only accept images.
    valid_images = [img for img in pil_images if img is not None]
    if valid_images:
        style = style_pipe(valid_images)
        aesthetic = aesthetic_pipe(valid_images)
        nsfw = nsfw_pipe(valid_images)
        scored = iter([a + b + c for a, b, c in zip(style, aesthetic, nsfw)])
    else:
        scored = iter(())
    # Keep results aligned with the inputs: failed images get a None slot.
    results = [next(scored) if img is not None else None for img in pil_images]

    label_data = {}
    if image and results and results[0] is not None:
        label_data = {row["label"]: row["score"] for row in results[0]}
    return results, label_data, valid_images if enable_gallery else None
# Example payload pre-filled in the URL input component.
_default_url_payload = {"urls": [
    'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/b9fb3257-6a54-455e-b636-9d61cf261676.jpg',
    'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/062eb9be-76eb-4d7e-9299-d1ebea14b46f.jpg',
    'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/8ff6d4f6-08d0-4a31-818c-4d32ab146f81.jpg',
]}

# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed from the statement order. In particular the placement of the
# "Run" button (directly under Blocks, after the row) is assumed -- confirm.
with gr.Blocks() as blocks:
    with gr.Row():
        # Left column: the three alternative input sources plus the gallery toggle.
        with gr.Column():
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(
                label="Multipls Images",
                file_types=["image"],
                file_count="multiple",
            )
            enable_gallery = gr.Checkbox(label="Enable Gallery", value=True)
            json = gr.JSON(label="Results", value=_default_url_payload)
        # Right column: the outputs written by predict().
        with gr.Column():
            label = gr.Label(label="style")
            results = gr.JSON(label="Results")
            gallery = gr.Gallery().style(grid=[2], height="auto")
    btn = gr.Button("Run")
    # The input order must match predict()'s positional parameters and the
    # output order must match its returned tuple.
    btn.click(
        fn=predict,
        inputs=[json, enable_gallery, image, files],
        outputs=[results, label, gallery],
        api_name="inference",
    )

blocks.queue()
blocks.launch(debug=True, inline=True)