diff --git "a/notebooks/text_to_audio_sd.ipynb" "b/notebooks/text_to_audio_sd.ipynb"
new file mode 100644
--- /dev/null
+++ "b/notebooks/text_to_audio_sd.ipynb"
@@ -0,0 +1,156 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append(\"../\")\n",
+ "\n",
+ "import torch, os\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "from diffusers import StableDiffusionPipeline\n",
+ "from huggingface_hub import snapshot_download\n",
+ "from converter import load_wav, mel_spectrogram, normalize_spectrogram, denormalize_spectrogram, Generator, get_mel_spectrogram_from_audio\n",
+ "from utils import pad_spec, image_add_color, torch_to_pil, normalize, denormalize\n",
+ "from IPython.display import display, Audio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pretrained_model_name_or_path = \"auffusion/auffusion-full-no-adapter\"\n",
+ "dtype = torch.float16\n",
+ "device = \"cuda\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if not os.path.isdir(pretrained_model_name_or_path):\n",
+ " pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vocoder = Generator.from_pretrained(pretrained_model_name_or_path, subfolder=\"vocoder\")\n",
+ "vocoder = vocoder.to(device=device, dtype=dtype)"
+ ]
+ },
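+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The diffusion pipeline below treats audio generation as image generation: it produces a normalized mel spectrogram, and the vocoder loaded above converts the denormalized spectrogram back into a 16 kHz waveform."
+ ]
+ },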
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=dtype)\n",
+ "pipe = pipe.to(device)"
+ ]
+ },
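+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Optional, and not part of the original pipeline: attention slicing trades a little speed for a lower peak GPU memory footprint. Skip this cell if your GPU has plenty of VRAM."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional: lower peak VRAM usage at a small speed cost\n",
+ "pipe.enable_attention_slicing()"
+ ]
+ },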
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Examples\n",
+ "prompt = \"A kitten mewing for attention\"\n",
+ "seed = 42"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a5114b06c3224f9c8010f17ba20e97e9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/100 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Generation \n",
+ "generator = torch.Generator(device=device).manual_seed(seed)\n",
+ "\n",
+ "with torch.autocast(\"cuda\"):\n",
+ " output_spec = pipe(\n",
+ " prompt=prompt, num_inference_steps=100, generator=generator, height=256, width=1024, output_type=\"pt\"\n",
+ " ).images[0]\n",
+ "\n",
+ "\n",
+ "denorm_spec = denormalize_spectrogram(output_spec)\n",
+ "denorm_spec_audio = vocoder.inference(denorm_spec)\n",
+ "\n",
+ "display(Audio(denorm_spec_audio, rate=16000))"
+ ]
+ },
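+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a follow-up, the cell below renders the generated spectrogram and writes the waveform to disk. `torch_to_pil` and `image_add_color` are this repo's own helpers (imported above), assumed here to render the spectrogram tensor as a colored PIL image; `soundfile` is an assumed extra dependency (`pip install soundfile`)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Render the generated mel spectrogram (repo helpers; assumed behavior)\n",
+ "spec_image = image_add_color(torch_to_pil(output_spec))\n",
+ "display(spec_image)\n",
+ "\n",
+ "# Write the 16 kHz waveform to disk; soundfile is an assumed dependency\n",
+ "import soundfile as sf\n",
+ "sf.write(\"output.wav\", np.ravel(denorm_spec_audio), samplerate=16000)"
+ ]
+ }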
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "TTA",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}