File size: 5,684 Bytes
631e673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659a5e1
 
 
 
 
 
 
 
 
631e673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659a5e1
 
631e673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
{
  "cells": [
    {
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "cell_type": "code",
      "source": [
        "# Optionally mount Google Drive at /content/drive\n",
        "from google.colab import drive\n",
        "mount_drive = True #@param {type:\"boolean\"}\n",
        "if mount_drive:\n",
        "  drive.mount('/content/drive')\n",
        "\n",
        "# One pip requirement per entry; requirements.txt holds them newline-separated\n",
        "requirements_txt = '\\n'.join([\n",
        "  'git+https://github.com/ArthurZucker/transformers.git@jukebox',\n",
        "  'accelerate',\n",
        "  'bitsandbytes==0.31.8',\n",
        "  'gradio',\n",
        "])\n",
        "\n",
        "with open('requirements.txt', 'w') as f:\n",
        "  f.write(requirements_txt)\n",
        "\n",
        "# Install everything listed above into the running kernel\n",
        "%pip install -r requirements.txt"
      ]
    },
    {
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "cell_type": "code",
      "source": [
        "# A simple Gradio app that converts music tokens to and from audio,\n",
        "# using JukeboxVQVAE as the model and Gradio as the UI.\n",
        "\n",
        "import os\n",
        "import sys\n",
        "\n",
        "from transformers import JukeboxVQVAE\n",
        "\n",
        "import gradio as gr\n",
        "import torch as t\n",
        "\n",
        "model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']\n",
        "\n",
        "# Colab is detected through the loaded modules, which is why `sys` is imported above\n",
        "if 'google.colab' in sys.modules:\n",
        "\n",
        "  cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:\"string\"}\n",
        "  # Connect to your Google Drive\n",
        "  from google.colab import drive\n",
        "  drive.mount('/content/drive')\n",
        "\n",
        "else:\n",
        "\n",
        "  # '~' is a shell convention; Python path APIs do not expand it, so do it explicitly\n",
        "  cache_path = os.path.expanduser('~/.cache/')\n",
        "\n",
        "class Convert:\n",
        "  \"\"\"Converters between token lists, token files and audio.\n",
        "\n",
        "  Methods are static and grouped by input type; all of them use the\n",
        "  global `model` loaded by `init()`.\n",
        "  \"\"\"\n",
        "\n",
        "  class TokenList:\n",
        "\n",
        "    @staticmethod\n",
        "    def to_tokens_file(tokens_list):\n",
        "      \"\"\"Validate `tokens_list`, `torch.save` it to a new .jt file and return its path.\"\"\"\n",
        "      import tempfile\n",
        "\n",
        "      # Validate first so an invalid input does not leave an empty file behind\n",
        "      tokens_list = validate_tokens_list(tokens_list)\n",
        "      # tempfile yields a unique path in a directory that exists; delete=False keeps\n",
        "      # the file after closing so the returned path stays readable by the caller\n",
        "      with tempfile.NamedTemporaryFile(suffix='.jt', delete=False) as f:\n",
        "        t.save(tokens_list, f)\n",
        "      return f.name\n",
        "\n",
        "    @staticmethod\n",
        "    def to_audio(tokens_list):\n",
        "      \"\"\"Decode the level-2 tokens of a validated token list into audio.\"\"\"\n",
        "      # TODO: Implement converting other levels besides 2\n",
        "      # NOTE(review): gr.Audio outputs normally take (sample_rate, ndarray) or a file\n",
        "      # path rather than a torch tensor -- confirm against the installed Gradio version\n",
        "      return model.decode(validate_tokens_list(tokens_list)[2:], start_level=2).squeeze(-1)\n",
        "\n",
        "  class TokensFile:\n",
        "\n",
        "    @staticmethod\n",
        "    def to_tokens_list(file):\n",
        "      \"\"\"Load a token list saved by `TokenList.to_tokens_file` and validate it.\"\"\"\n",
        "      # Accept either a plain path or an uploaded-file object that exposes `.name`\n",
        "      # NOTE(review): torch.load unpickles its input; only load files you trust\n",
        "      return validate_tokens_list(t.load(getattr(file, 'name', file)))\n",
        "\n",
        "    @staticmethod\n",
        "    def to_audio(file):\n",
        "      \"\"\"Decode the token file `file` into audio.\"\"\"\n",
        "      return Convert.TokenList.to_audio(Convert.TokensFile.to_tokens_list(file))\n",
        "\n",
        "  class Audio:\n",
        "\n",
        "    @staticmethod\n",
        "    def to_tokens_list(audio):\n",
        "      \"\"\"Encode `audio` into one token tensor per VQ-VAE level.\"\"\"\n",
        "      # start_level=0: encode returns only the levels from start_level upward, and\n",
        "      # validate_tokens_list (applied when saving) requires all 3 levels\n",
        "      # NOTE(review): assumes `audio` is a (time, channels) tensor; gr.Audio inputs\n",
        "      # usually arrive as (sample_rate, ndarray) -- TODO confirm and convert/resample\n",
        "      return model.encode(audio.unsqueeze(0), start_level=0)\n",
        "\n",
        "    @staticmethod\n",
        "    def to_tokens_file(audio):\n",
        "      \"\"\"Encode `audio` and save the resulting token list to a .jt file; return its path.\"\"\"\n",
        "      return Convert.TokenList.to_tokens_file(Convert.Audio.to_tokens_list(audio))\n",
        "\n",
        "def init():\n",
        "  \"\"\"Load the Jukebox VQ-VAE into the global `model` unless it is already loaded.\"\"\"\n",
        "  global model\n",
        "\n",
        "  # Re-running the cell keeps the existing weights instead of reloading them\n",
        "  if 'model' in globals():\n",
        "    print(\"Model already initialized\")\n",
        "    return\n",
        "\n",
        "  model = JukeboxVQVAE.from_pretrained(\n",
        "    model_id,\n",
        "    cache_dir=f\"{cache_path}/jukebox/models\",\n",
        "    torch_dtype=t.float16,\n",
        "  )\n",
        "\n",
        "def validate_tokens_list(tokens_list):\n",
        "  \"\"\"Check that `tokens_list` has the layout of a music-tokens file and return it.\n",
        "\n",
        "  Requirements:\n",
        "  - it is a sized collection of exactly 3 tensors,\n",
        "  - all tensors have the same number of dimensions (at least 2),\n",
        "  - all tensors have the same size along dimension 0,\n",
        "  - the size along dimension 1 decreases (or stays the same) from index 0 to 2.\n",
        "\n",
        "  Raises ValueError when a requirement is violated. Explicit raises are used\n",
        "  instead of `assert` so the checks still run under `python -O`.\n",
        "  \"\"\"\n",
        "  try:\n",
        "    count = len(tokens_list)\n",
        "  except TypeError:\n",
        "    count = None\n",
        "  if count != 3:\n",
        "    raise ValueError(\"Invalid file format: expecting a list of 3 tensors\")\n",
        "\n",
        "  # Anything without a `.shape` (plain lists, numbers, strings) is not a tensor\n",
        "  shapes = [getattr(tokens, 'shape', None) for tokens in tokens_list]\n",
        "  if any(shape is None for shape in shapes):\n",
        "    raise ValueError(\"Invalid file format: expecting a list of 3 tensors\")\n",
        "\n",
        "  if not len(shapes[0]) == len(shapes[1]) == len(shapes[2]):\n",
        "    raise ValueError(\"Invalid file format: each tensor in the list should have the same number of dimensions\")\n",
        "\n",
        "  # Dimensions 0 and 1 are compared below, so lower-rank tensors are rejected here\n",
        "  if len(shapes[0]) < 2:\n",
        "    raise ValueError(\"Invalid file format: each tensor in the list should have at least 2 dimensions\")\n",
        "\n",
        "  if not shapes[0][0] == shapes[1][0] == shapes[2][0]:\n",
        "    raise ValueError(\"Invalid file format: the shape along dimension 0 should be the same for all tensors in the list\")\n",
        "\n",
        "  if not shapes[0][1] >= shapes[1][1] >= shapes[2][1]:\n",
        "    raise ValueError(\"Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2\")\n",
        "\n",
        "  return tokens_list\n",
        "\n",
        "\n",
        "with gr.Blocks() as ui:\n",
        "\n",
        "  # Token file: upload one to decode it, or download the output of an encode\n",
        "  tokens = gr.File(label='music_tokens_file')\n",
        "\n",
        "  # Audio: shows decoded output, or takes an upload to encode\n",
        "  audio = gr.Audio(label='audio')\n",
        "\n",
        "  # Decoding (tokens -> audio) is the primary action; encoding is secondary\n",
        "  decode_button = gr.Button(\"Convert tokens to audio\", variant='primary')\n",
        "  encode_button = gr.Button(\"Convert audio to tokens\", variant='secondary')\n",
        "  decode_button.click(Convert.TokensFile.to_audio, tokens, audio)\n",
        "  encode_button.click(Convert.Audio.to_tokens_file, audio, tokens)\n",
        "\n",
        "if __name__ == '__main__':\n",
        "  # The Convert helpers read the global `model`, so load it before serving the UI\n",
        "  init()\n",
        "  ui.launch()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.7.5"
    },
    "orig_nbformat": 4
  },
  "nbformat": 4,
  "nbformat_minor": 2
}