{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T09:44:30.641366Z",
     "start_time": "2024-12-09T09:44:11.789050Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import gradio as gr\n",
    "from diffusers import DiffusionPipeline\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from PIL import Image\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ddf33e0d3abacc2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "#append current path\n",
    "sys.path.extend(\"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/release/hf_demo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "643e49fd601daf8f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T09:44:35.790962Z",
     "start_time": "2024-12-09T09:44:35.779496Z"
    }
   },
   "outputs": [],
   "source": [
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e03aae2a4e5676dd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T09:44:44.157412Z",
     "start_time": "2024-12-09T09:44:37.138452Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9df8347307674ba8afb0250e23109aa1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "device = \"cuda\"\n",
    "# Load the art-free base model once; the demo functions defined below all reuse `pipe`.\n",
    "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "83916bc68ff5d914",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T09:44:52.694399Z",
     "start_time": "2024-12-09T09:44:44.210695Z"
    }
   },
   "outputs": [],
   "source": [
    "from inference import get_lora_network, inference, get_validation_dataloader\n",
    "lora_map = {\n",
    "    \"None\": \"None\",\n",
    "    \"Andre Derain\": \"andre-derain_subset1\",\n",
    "    \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
    "    \"Andy Warhol\": \"andy_subset1\",\n",
    "    \"Walter Battiss\": \"walter-battiss_subset2\",\n",
    "    \"Camille Corot\": \"camille-corot_subset1\",\n",
    "    \"Claude Monet\": \"monet_subset2\",\n",
    "    \"Pablo Picasso\": \"picasso_subset1\",\n",
    "    \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
    "    \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
    "    \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
    "    \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
    "    \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
    "    \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
    "    \"Gustav Klimt\": \"klimt_subset3\",\n",
    "    \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
    "    \"Henri Matisse\": \"henri-matisse_subset1\",\n",
    "    \"Joan Miro\": \"joan-miro_subset2\",\n",
    "}\n",
    "\n",
    "def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):\n",
    "    \"\"\"Generate images from scratch, optionally with an artist adapter applied.\n",
    "\n",
    "    Args:\n",
    "        adapter_choice: key of `lora_map` (\"None\" runs without an adapter file).\n",
    "        prompt: text prompt; repeated once per requested sample.\n",
    "        samples: number of images to generate.\n",
    "        seed: random seed forwarded to `inference`.\n",
    "        steps: number of sampling steps forwarded to `inference`.\n",
    "        guidance_scale: guidance scale forwarded to `inference`.\n",
    "\n",
    "    Returns:\n",
    "        `inference(...)[0][1.0]`, i.e. the outputs for adapter scale 1.0,\n",
    "        which the Gradio gallery displays.\n",
    "    \"\"\"\n",
    "    # Gradio sliders send JSON numbers that may arrive as floats; list repetition\n",
    "    # (`[prompt] * samples`) requires an int, and seeds/step counts should be ints too.\n",
    "    samples, seed, steps = int(samples), int(seed), int(steps)\n",
    "    # Treat a missing/cleared dropdown value as \"no adapter\" instead of raising KeyError.\n",
    "    if adapter_choice is None:\n",
    "        adapter_choice = \"None\"\n",
    "    adapter_path = lora_map[adapter_choice]\n",
    "    if adapter_path not in [None, \"None\"]:\n",
    "        adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
    "\n",
    "    prompts = [prompt] * samples\n",
    "    infer_loader = get_validation_dataloader(prompts)\n",
    "    network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
    "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
    "                            height=512, width=512, scales=[1.0],\n",
    "                            save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,\n",
    "                            start_noise=-1, show=False, style_prompt=\"sks art\", no_load=True,\n",
    "                            from_scratch=True)[0][1.0]\n",
    "    return pred_images\n",
    "\n",
    "def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0,\n",
    "                               steps:int=20, guidance_scale:float=7.5):\n",
    "    \"\"\"Run image-to-image stylization, rendering at adapter scales 0.0 and 1.0.\n",
    "\n",
    "    Args:\n",
    "        adapter_path: adapter weights path, passed unchanged to `get_lora_network`\n",
    "            with train method \"all_up\".\n",
    "        prompts: text prompts for the dataloader (presumably one per input image - TODO confirm).\n",
    "        image: input images handed to `get_validation_dataloader` alongside `prompts`.\n",
    "        start_noise: forwarded to `inference` (presumably the timestep denoising starts from).\n",
    "        seed: random seed forwarded to `inference`.\n",
    "        steps: number of sampling steps (default 20, the previously hard-coded value).\n",
    "        guidance_scale: guidance scale (default 7.5, the previously hard-coded value).\n",
    "\n",
    "    Returns:\n",
    "        The full return value of `inference` (not indexed by scale).\n",
    "    \"\"\"\n",
    "    infer_loader = get_validation_dataloader(prompts, image)\n",
    "    network = get_lora_network(pipe.unet, adapter_path, \"all_up\")[\"network\"]\n",
    "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
    "                            height=512, width=512, scales=[0., 1.],\n",
    "                            save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,\n",
    "                            start_noise=start_noise, show=True, style_prompt=\"sks art\", no_load=True,\n",
    "                            from_scratch=False)\n",
    "    return pred_images\n",
    "\n",
    "# def infer(prompt, samples, steps, scale, seed):\n",
    "#     generator = torch.Generator(device=device).manual_seed(seed)\n",
    "#     images_list = pipe(  # type: ignore\n",
    "#         [prompt] * samples,\n",
    "#         num_inference_steps=steps,\n",
    "#         guidance_scale=scale,\n",
    "#         generator=generator,\n",
    "#     )\n",
    "#     images = []\n",
    "#     safe_image = Image.open(r\"data/unsafe.png\")\n",
    "#     print(images_list)\n",
    "#     for i, image in enumerate(images_list[\"images\"]):  # type: ignore\n",
    "#         if images_list[\"nsfw_content_detected\"][i]:  # type: ignore\n",
    "#             images.append(safe_image)\n",
    "#         else:\n",
    "#             images.append(image)\n",
    "#     return images\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "aa33e9d104023847",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T12:09:39.339583Z",
     "start_time": "2024-12-09T12:09:38.953936Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<gradio.components.slider.Slider object at 0x7fa12d3a5280>\n",
      "Running on local URL:  http://127.0.0.1:7876\n",
      "Running on public URL: https://be7cce8fec75395c82.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://be7cce8fec75395c82.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train method: None\n",
      "Rank: 1, Alpha: 1\n",
      "create LoRA for U-Net: 0 modules.\n",
      "save dir: None\n",
      "['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
      "  return F.conv2d(input, weight, bias, self.stride,\n",
      "\n",
      "00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00,  6.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
     ]
    }
   ],
   "source": [
    "block = gr.Blocks()\n",
    "# Text-to-image UI: prompt + Run button, a results gallery, and collapsible advanced options.\n",
    "with block:\n",
    "    with gr.Group():\n",
    "        with gr.Row():\n",
    "            text = gr.Textbox(\n",
    "                label=\"Enter your prompt\",\n",
    "                max_lines=2,\n",
    "                placeholder=\"Enter your prompt\",\n",
    "                container=False,\n",
    "                value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
    "            )\n",
    "            btn = gr.Button(\"Run\", scale=0)\n",
    "        gallery = gr.Gallery(\n",
    "            label=\"Generated images\",\n",
    "            show_label=False,\n",
    "            elem_id=\"gallery\",\n",
    "            columns=[2],\n",
    "        )\n",
    "\n",
    "        # An Accordion gives the \"Advanced options\" header a working show/hide toggle;\n",
    "        # a plain Button whose click handler is fn=None (with no JS) does nothing.\n",
    "        # open=True keeps the options (including the adapter choice) visible by default.\n",
    "        with gr.Accordion(\"Advanced options\", open=True, elem_id=\"advanced-btn\"):\n",
    "            with gr.Row(elem_id=\"advanced-options\"):\n",
    "                adapter_choice = gr.Dropdown(\n",
    "                    label=\"Choose adapter\",\n",
    "                    # Built from lora_map so every offered name is one demo_inference_gen can resolve.\n",
    "                    choices=list(lora_map),\n",
    "                    value=\"None\",\n",
    "                )\n",
    "                samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1)\n",
    "                steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
    "                scale = gr.Slider(\n",
    "                    label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n",
    "                )\n",
    "                seed = gr.Slider(\n",
    "                    label=\"Seed\",\n",
    "                    minimum=0,\n",
    "                    maximum=2147483647,\n",
    "                    step=1,\n",
    "                    randomize=True,\n",
    "                )\n",
    "\n",
    "    # Generate on Enter in the prompt box or on the Run button; the input order matches\n",
    "    # demo_inference_gen(adapter_choice, prompt, samples, seed, steps, guidance_scale).\n",
    "    gr.on(\n",
    "        [text.submit, btn.click],\n",
    "        demo_inference_gen,\n",
    "        inputs=[adapter_choice, text, samples, seed, steps, scale],\n",
    "        outputs=gallery,\n",
    "    )\n",
    "\n",
    "# share=True also publishes a temporary public *.gradio.live URL; use share=False to stay local.\n",
    "block.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3239c12167a5f2cd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}