{ "cells": [
  { "cell_type": "code", "execution_count": 1, "id": "initial_id", "metadata": { "ExecuteTime": { "end_time": "2024-12-09T09:44:30.641366Z", "start_time": "2024-12-09T09:44:11.789050Z" } }, "outputs": [], "source": [
    "import os\n",
    "\n",
    "# Select the GPU *before* importing torch/diffusers: CUDA reads CUDA_VISIBLE_DEVICES only when it\n",
    "# is first initialised, so assigning it after a library has already touched the GPU has no effect.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "import gradio as gr\n",
    "from diffusers import DiffusionPipeline\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from PIL import Image\n"
  ] },
  { "cell_type": "code", "execution_count": 2, "id": "ddf33e0d3abacc2c", "metadata": {}, "outputs": [], "source": [
    "import sys\n",
    "\n",
    "# Make the demo's local modules (e.g. `inference`) importable. Adapters are later loaded from\n",
    "# relative paths (data/Art_adapters/...), so the working directory has to be the demo folder anyway;\n",
    "# set HF_DEMO_DIR to use a different location.\n",
    "# Use append, not extend: sys.path.extend(<str>) adds every *character* of the path as its own entry.\n",
    "DEMO_DIR = os.environ.get(\"HF_DEMO_DIR\", os.getcwd())\n",
    "if DEMO_DIR not in sys.path:\n",
    "    sys.path.append(DEMO_DIR)"
  ] },
  { "cell_type": "code", "execution_count": 4, "id": "e03aae2a4e5676dd", "metadata": { "ExecuteTime": { "end_time": "2024-12-09T09:44:44.157412Z", "start_time": "2024-12-09T09:44:37.138452Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9df8347307674ba8afb0250e23109aa1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
    "device = \"cuda\"\n",
    "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\").to(device)"
  ] },
  { "cell_type": "code", "execution_count": 5, "id": "83916bc68ff5d914", "metadata": { "ExecuteTime": { "end_time": "2024-12-09T09:44:52.694399Z", "start_time": "2024-12-09T09:44:44.210695Z" } }, "outputs": [], "source": [
    "from inference import get_lora_network, inference, get_validation_dataloader\n",
    "\n",
    "# Dropdown label -> adapter folder under data/Art_adapters (\"None\" = run without an adapter).\n",
    "lora_map = {\n",
    "    \"None\": \"None\",\n",
    "    \"Andre Derain\": \"andre-derain_subset1\",\n",
    "    \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
    "    \"Andy Warhol\": \"andy_subset1\",\n",
    "    \"Walter Battiss\": \"walter-battiss_subset2\",\n",
    "    \"Camille Corot\": \"camille-corot_subset1\",\n",
    "    \"Claude Monet\": \"monet_subset2\",\n",
    "    \"Pablo Picasso\": \"picasso_subset1\",\n",
    "    \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
    "    \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
    "    \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
    "    \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
    "    \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
    "    \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
    "    \"Gustav Klimt\": \"klimt_subset3\",\n",
    "    \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
    "    \"Henri Matisse\": \"henri-matisse_subset1\",\n",
    "    \"Joan Miro\": \"joan-miro_subset2\",\n",
    "}\n",
    "\n",
    "# Checkpoint file inside each adapter folder (path is relative to the working directory).\n",
    "ADAPTER_PATH_TEMPLATE = \"data/Art_adapters/{}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
    "\n",
    "\n",
    "def demo_inference_gen(adapter_choice: str, prompt: str, samples: int = 1, seed: int = 0, steps=50, guidance_scale=7.5):\n",
    "    \"\"\"Generate images for `prompt` from scratch, optionally through an Art Adapter.\n",
    "\n",
    "    Args:\n",
    "        adapter_choice: a key of `lora_map` (the dropdown label); \"None\" runs without an adapter.\n",
    "        prompt: text prompt, repeated `samples` times.\n",
    "        samples: number of copies of the prompt fed to the dataloader.\n",
    "        seed: random seed forwarded to `inference`.\n",
    "        steps: number of sampling steps forwarded to `inference`.\n",
    "        guidance_scale: guidance scale forwarded to `inference`.\n",
    "\n",
    "    Returns:\n",
    "        The images `inference` produced for adapter scale 1.0.\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: if `adapter_choice` is not a key of `lora_map` (shown in the UI instead of a bare KeyError).\n",
    "    \"\"\"\n",
    "    if adapter_choice not in lora_map:\n",
    "        raise gr.Error(f\"Unknown adapter: {adapter_choice!r}\")\n",
    "    adapter_path = lora_map[adapter_choice]\n",
    "    if adapter_path not in (None, \"None\"):\n",
    "        adapter_path = ADAPTER_PATH_TEMPLATE.format(adapter_path)\n",
    "\n",
    "    # Gradio sliders may deliver floats; list repetition and step counts need ints.\n",
    "    prompts = [prompt] * int(samples)\n",
    "    infer_loader = get_validation_dataloader(prompts)\n",
    "    network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
    "    results = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
    "                        height=512, width=512, scales=[1.0],\n",
    "                        save_dir=None, seed=int(seed), steps=int(steps), guidance_scale=guidance_scale,\n",
    "                        start_noise=-1, show=False, style_prompt=\"sks art\", no_load=True,\n",
    "                        from_scratch=True)\n",
    "    # results[0] appears to map adapter scale -> images (only scale 1.0 was requested) -- TODO confirm.\n",
    "    return results[0][1.0]\n",
    "\n",
    "\n",
    "def demo_inference_stylization(adapter_path: str, prompts: list, image: list, start_noise=800, seed: int = 0):\n",
    "    \"\"\"Stylize existing images with an Art Adapter (not used by the Gradio UI in this notebook).\n",
    "\n",
    "    Args:\n",
    "        adapter_path: path to an adapter checkpoint file (not a `lora_map` key).\n",
    "        prompts: prompts passed to `get_validation_dataloader` together with `image`.\n",
    "        image: input images to stylize.\n",
    "        start_noise: forwarded to `inference` as `start_noise`\n",
    "            (presumably the noise level denoising restarts from -- TODO confirm).\n",
    "        seed: random seed forwarded to `inference`.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `inference` returns for adapter scales 0.0 and 1.0.\n",
    "    \"\"\"\n",
    "    infer_loader = get_validation_dataloader(prompts, image)\n",
    "    network = get_lora_network(pipe.unet, adapter_path, \"all_up\")[\"network\"]\n",
    "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
    "                            height=512, width=512, scales=[0., 1.],\n",
    "                            save_dir=None, seed=seed, steps=20, guidance_scale=7.5,\n",
    "                            start_noise=start_noise, show=True, style_prompt=\"sks art\", no_load=True,\n",
    "                            from_scratch=False)\n",
    "    return pred_images\n"
  ] },
  { "cell_type": "code", "execution_count": 6, "id": "aa33e9d104023847", "metadata": { "ExecuteTime": { "end_time": "2024-12-09T12:09:39.339583Z", "start_time": "2024-12-09T12:09:38.953936Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<gradio.components.slider.Slider object at 0x7fa12d3a5280>\n", "Running on local URL: http://127.0.0.1:7876\n", "Running on public URL: https://be7cce8fec75395c82.gradio.live\n", "\n", "This share link expires in 72 hours. 
For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "<div><iframe src=\"https://be7cce8fec75395c82.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "name": "stdout", "output_type": "stream", "text": [ "Train method: None\n", "Rank: 1, Alpha: 1\n", "create LoRA for U-Net: 0 modules.\n", "save dir: None\n", "['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n", " return F.conv2d(input, weight, bias, self.stride,\n", "\n", "00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00, 6.90it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n" ] } ], "source": [
    "# Gradio UI: text-to-image generation, optionally through one of the Art Adapters in lora_map.\n",
    "block = gr.Blocks()\n",
    "with block:\n",
    "    with gr.Group():\n",
    "        with gr.Row():\n",
    "            text = gr.Textbox(\n",
    "                label=\"Enter your prompt\",\n",
    "                max_lines=2,\n",
    "                placeholder=\"Enter your prompt\",\n",
    "                container=False,\n",
    "                value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
    "            )\n",
    "            btn = gr.Button(\"Run\", scale=0)\n",
    "        gallery = gr.Gallery(\n",
    "            label=\"Generated images\",\n",
    "            show_label=False,\n",
    "            elem_id=\"gallery\",\n",
    "            columns=[2],\n",
    "        )\n",
    "\n",
    "    # An Accordion gives the advanced options a working show/hide toggle; a bare button wired with\n",
    "    # click(None, [], text) has neither a function nor JS, so it never toggled anything.\n",
    "    with gr.Accordion(\"Advanced options\", open=True):\n",
    "        with gr.Row(elem_id=\"advanced-options\"):\n",
    "            adapter_choice = gr.Dropdown(\n",
    "                label=\"Choose adapter\",\n",
    "                # Derived from lora_map so the menu cannot offer a name the callback cannot resolve.\n",
    "                choices=list(lora_map),\n",
    "                value=\"None\",\n",
    "            )\n",
    "            samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1)\n",
    "            steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
    "            scale = gr.Slider(label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1)\n",
    "            seed = gr.Slider(\n",
    "                label=\"Seed\",\n",
    "                minimum=0,\n",
    "                maximum=2147483647,\n",
    "                step=1,\n",
    "                randomize=True,\n",
    "            )\n",
    "\n",
    "    # Generate on Enter in the prompt box or on the Run button.\n",
    "    gr.on(\n",
    "        [text.submit, btn.click],\n",
    "        demo_inference_gen,\n",
    "        inputs=[adapter_choice, text, samples, seed, steps, scale],\n",
    "        outputs=gallery,\n",
    "    )\n",
    "\n",
    "block.launch(share=True)"
  ] }
 ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }