{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a885cf5d-c525-4f5b-a8e4-f67d2f699909",
   "metadata": {},
   "source": [
    "## Copyright 2023 Google LLC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d891d022-8979-40d4-848f-ecb84c17f12c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "outputs": [],
   "source": [
    "# Copyright 2023 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#      http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "540d8642-c203-471c-a66d-0d43aabb0706",
   "metadata": {},
   "source": [
    "# StyleAligned over SDXL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d54ea7-f7ab-4548-9b10-ece87216dc18",
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import StableDiffusionXLPipeline, DDIMScheduler\n",
    "import torch\n",
    "import mediapy\n",
    "import sa_handler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2f6f1e6-445f-47bc-b9db-0301caeb7490",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Set up the models: DDIM scheduler, SDXL pipeline, and the StyleAligned handler.\n",
    "\n",
    "scheduler = DDIMScheduler(\n",
    "    beta_start=0.00085,\n",
    "    beta_end=0.012,\n",
    "    beta_schedule=\"scaled_linear\",\n",
    "    clip_sample=False,\n",
    "    set_alpha_to_one=False,\n",
    ")\n",
    "\n",
    "# Load the fp16 variant of the SDXL base checkpoint with the scheduler above,\n",
    "# then move the pipeline to the GPU.\n",
    "pipeline = StableDiffusionXLPipeline.from_pretrained(\n",
    "    \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
    "    scheduler=scheduler,\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    ")\n",
    "pipeline = pipeline.to(\"cuda\")\n",
    "\n",
    "# Wrap the pipeline in the StyleAligned handler. The configuration below turns on\n",
    "# share_attention together with adain_queries / adain_keys; share_group_norm,\n",
    "# share_layer_norm and adain_values stay off.\n",
    "handler = sa_handler.Handler(pipeline)\n",
    "sa_args = sa_handler.StyleAlignedArgs(\n",
    "    share_group_norm=False,\n",
    "    share_layer_norm=False,\n",
    "    share_attention=True,\n",
    "    adain_queries=True,\n",
    "    adain_keys=True,\n",
    "    adain_values=False,\n",
    ")\n",
    "\n",
    "handler.register(sa_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cca9256-0ce0-45c3-9cba-68c7eff1452f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# run StyleAligned\n",
    "\n",
    "# Fixed seed so that Restart Kernel -> Run All regenerates the same images.\n",
    "# Change the value to sample a different set.\n",
    "SEED = 0\n",
    "\n",
    "# All prompts are passed to the pipeline in a single call; they differ only in\n",
    "# the object and share the same style suffix.\n",
    "sets_of_prompts = [\n",
    "  \"a toy train. macro photo. 3d game asset\",\n",
    "  \"a toy airplane. macro photo. 3d game asset\",\n",
    "  \"a toy bicycle. macro photo. 3d game asset\",\n",
    "  \"a toy car. macro photo. 3d game asset\",\n",
    "  \"a toy boat. macro photo. 3d game asset\",\n",
    "]\n",
    "\n",
    "# A seeded generator makes the sampled noise reproducible; a CPU generator is\n",
    "# used so the result does not depend on the GPU model.\n",
    "generator = torch.Generator(device=\"cpu\").manual_seed(SEED)\n",
    "images = pipeline(sets_of_prompts, generator=generator).images\n",
    "mediapy.show_images(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d819ad6d-0c19-411f-ba97-199909f64805",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}