{
"cells": [
{
"metadata": {},
"execution_count": null,
"outputs": [],
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"mount_drive = True #@param {type:\"boolean\"}\n",
"if mount_drive:\n",
" drive.mount('/content/drive')\n",
"\n",
"requirements_txt = \"git+https://github.com/ArthurZucker/transformers.git@jukebox\\naccelerate\\nbitsandbytes==0.31.8\\ngradio\"\n",
"\n",
"# Save the requirements.txt file\n",
"with open('requirements.txt', 'w') as f:\n",
" f.write(requirements_txt)\n",
"\n",
"# Install the dependencies\n",
"%pip install -r requirements.txt"
]
},
{
"metadata": {},
"execution_count": null,
"outputs": [],
"cell_type": "code",
"source": [
"# A simple gradio app that converts music tokens to and from audio using JukeboxVQVAE as the model and Gradio as the UI\n",
"\n",
"import sys\n",
"\n",
"import torch as t\n",
"from transformers import JukeboxVQVAE\n",
"import gradio as gr\n",
"\n",
"model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']\n",
"\n",
"if 'google.colab' in sys.modules:\n",
"\n",
" cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:\"string\"}\n",
" # Connect to your Google Drive\n",
" from google.colab import drive\n",
" drive.mount('/content/drive')\n",
"\n",
"else:\n",
"\n",
" cache_path = '~/.cache/'\n",
"\n",
"class Convert:\n",
"\n",
" class TokenList:\n",
"\n",
" def to_tokens_file(tokens_list):\n",
" # temporary random file name\n",
" filename = f\"tmp/{t.randint(0, 1000000)}.jt\"\n",
" t.save(validate_tokens_list(tokens_list), filename)\n",
" return filename\n",
"\n",
" def to_audio(tokens_list):\n",
" return model.decode(validate_tokens_list(tokens_list)[2:], start_level=2).squeeze(-1)\n",
" # TODO: Implement converting other levels besides 2\n",
"\n",
" class TokensFile:\n",
"\n",
" def to_tokens_list(file):\n",
" return validate_tokens_list(t.load(file))\n",
"\n",
" def to_audio(file):\n",
" return Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))\n",
"\n",
" class Audio:\n",
"\n",
" def to_tokens_list(audio):\n",
" return model.encode(audio.unsqueeze(0), start_level=2)\n",
" # (TODO: Generated by copilot, check if it works)\n",
"\n",
" def to_tokens_file(audio):\n",
" return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))\n",
"\n",
"def init():\n",
" global model\n",
"\n",
" try:\n",
" model\n",
" print(\"Model already initialized\")\n",
" except NameError:\n",
" model = JukeboxVQVAE.from_pretrained(\n",
" model_id,\n",
" torch_dtype = t.float16,\n",
" cache_dir = f\"{cache_path}/jukebox/models\"\n",
" )\n",
"\n",
"def validate_tokens_list(tokens_list):\n",
" # Make sure that:\n",
" # - tokens_list is a list of exactly 3 torch tensors\n",
" assert len(tokens_list) == 3, \"Invalid file format: expecting a list of 3 tensors\"\n",
"\n",
" # - each has the same number of dimensions\n",
" assert len(tokens_list[0].shape) == len(tokens_list[1].shape) == len(tokens_list[2].shape), \"Invalid file format: each tensor in the list should have the same number of dimensions\"\n",
"\n",
" # - the shape along dimension 0 is the same\n",
" assert tokens_list[0].shape[0] == tokens_list[1].shape[0] == tokens_list[2].shape[0], \"Invalid file format: the shape along dimension 0 should be the same for all tensors in the list\"\n",
"\n",
" # - the shape along dimension 1 increases (or stays the same) as we go from 0 to 2\n",
" assert tokens_list[0].shape[1] >= tokens_list[1].shape[1] >= tokens_list[2].shape[1], \"Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2\"\n",
"\n",
" return tokens_list\n",
"\n",
"\n",
"with gr.Blocks() as ui:\n",
"\n",
" # File input to upload or download the music tokens file\n",
" tokens = gr.File(label='music_tokens_file')\n",
"\n",
" # Audio output to play or upload the generated audio\n",
" audio = gr.Audio(label='audio')\n",
"\n",
" # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)\n",
" gr.Button(\"Convert tokens to audio\", variant='primary').click(Convert.TokensFile.to_audio, tokens, audio)\n",
" gr.Button(\"Convert audio to tokens\", variant='secondary').click(Convert.Audio.to_tokens_file, audio, tokens)\n",
"\n",
"if __name__ == '__main__':\n",
" init()\n",
" ui.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.7.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}