diff --git "a/imdb.ipynb" "b/imdb.ipynb"
new file mode 100644
--- /dev/null
+++ "b/imdb.ipynb"
@@ -0,0 +1,7664 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ODbwgRIAxdCh"
+ },
+ "source": [
+ "# Gemma Scope Tutorial\n",
+ "\n",
+ "This is a barebones tutorial on how to use [Gemma Scope](https://huggingface.co/google/gemma-scope), Google DeepMind's suite of Sparse Autoencoders (SAEs) on every layer and sublayer of Gemma 2 2B and 9B. Sparse Autoencoders are an interpretability tool that acts like a \"microscope\" on language model activations. They let us zoom in on dense, compressed activations and expand them into a larger but sparser and seemingly more interpretable form, which can be very useful when doing interpretability research!\n",
+ "\n",
+ "**Learn more:**\n",
+ "* If you want to learn about Gemma Scope without writing any code, check out [this interactive demo](https://neuronpedia.org/gemma-scope) courtesy of [Neuronpedia](https://neuronpedia.org).\n",
+ "* For an overview of Gemma Scope check out [the blog post](https://deepmind.google/discover/blog/gemma-scope-helping-the-safety-community-shed-light-on-the-inner-workings-of-language-models).\n",
+ "* See [the technical report](https://storage.googleapis.com/gemma-scope/gemma-scope-report.pdf) for the technical details.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rB2BasaDOm_t"
+ },
+ "source": [
+ "\n",
+ "For illustrative purposes, we begin with a lightweight tutorial that uses as few libraries as possible to outline how Gemma Scope works, and what Sparse Autoencoders are doing. This is deliberately a fairly minimalist tutorial, designed to make clear what is actually going on, but does not model research best practices.\n",
+ "\n",
+ "For any serious research with Gemma Scope, **we recommend using the [SAELens](https://jbloomaus.github.io/SAELens/) and [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) libraries**. See [this tutorial](https://colab.research.google.com/github/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb) for how to use [SAELens](https://jbloomaus.github.io/SAELens/) in practice.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RvDc2KCO9DYS"
+ },
+ "source": [
+ "## Loading the Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fB9bB3mJ8R1H"
+ },
+ "source": [
+ "First, let's load the model.\n",
+ "\n",
+ "For simplicity we do this straight from [HuggingFace transformers](https://huggingface.co/docs/transformers/en/index), rather than using an interpretability-focused library like [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) or [nnsight](https://nnsight.net/)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "id": "nOBcV4om7mrT"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ffea7bc53a17446cacb9a35ae3adc0a1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some parameters are on the meta device device because they were offloaded to the cpu.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer\n",
+ "from huggingface_hub import hf_hub_download, notebook_login\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "\n",
+ "torch.set_grad_enabled(False) # avoid blowing up mem\n",
+ "\n",
+ "params = {\n",
+ " \"model_name\" : \"google/gemma-2-9b-it\",\n",
+ " \"width\" : \"16k\",\n",
+ " \"layer\" : 31,\n",
+ " \"l0\" : 76,\n",
+ " \"sae_repo_id\": \"google/gemma-scope-9b-it-res\",\n",
+ " \"filename\" : \"layer_31/width_16k/average_l0_76/params.npz\"\n",
+ "}\n",
+ "\n",
+ "# params = {\n",
+ "# \"model_name\" : \"google/gemma-2-2b\",\n",
+ "# \"width\" : \"16k\",\n",
+ "# \"layer\" : 23,\n",
+ "# \"l0\" : 74,\n",
+ "# \"sae_repo_id\": \"google/gemma-scope-2b-pt-res\",\n",
+ "# \"filename\" : \"layer_23/width_16k/average_l0_74/params.npz\"\n",
+ "# }\n",
+ "\n",
+ "model_name = params[\"model_name\"]\n",
+ "width = params[\"width\"]\n",
+ "layer = params[\"layer\"]\n",
+ "l0 = params[\"l0\"]\n",
+ "sae_repo_id = params[\"sae_repo_id\"]\n",
+ "filename = params[\"filename\"]\n",
+ "\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " device_map='auto',\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "\n",
+ "filename = f\"layer_{layer}/width_{width}/average_l0_{l0}/params.npz\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8Q6sQSaAN7T7"
+ },
+ "source": [
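+ "The configuration above loads Gemma 2 9B IT (`google/gemma-2-9b-it`) together with a residual-stream SAE from `google/gemma-scope-9b-it-res`. The commented-out `params` block switches to Gemma 2 2B, the smallest model that Gemma Scope works for; there we load the base model, not the chat model, since that's where those SAEs were trained (though the SAEs seem to transfer OK to the chat models). To download the model weights, you'll first need to authenticate with Hugging Face."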
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MZkgvglU9GdW"
+ },
+ "source": [
+ "Now that we've loaded the model, let's try running it! We give it the prompt \"Would you be able to travel through time using a wormhole?\" and print the generated output."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qZECwzKi9dGv"
+ },
+ "source": [
+ "## Loading a Sparse Autoencoder"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wQSE4K5KVmGY"
+ },
+ "source": [
+ "OK, so we have got Gemma 2 loaded and can sample from it to get sensible stuff. Now, let's load one of our SAEs.\n",
+ "\n",
+ "Gemma Scope actually contains over four hundred SAEs, but for now we'll just load one on the residual stream at the end of the layer chosen in `params` above: layer 31 for the Gemma 2 9B IT configuration, or layer 23 for the commented-out Gemma 2 2B configuration. Note that layers are indexed from 0, so layer 31 is the 32nd layer. This is a fairly late layer, so the model should have had time to form more abstract concepts!\n",
+ "\n",
+ "See [the final section](https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp?authuser=2#scrollTo=E7zjkVseLSPp) for more information on how to load all the other SAEs in Gemma Scope\n",
+ "\n",
+ "**What is the residual stream?**\n",
+ "\n",
+ "Transformers have skip connections, which means that the output of each block is the output of each sublayer *plus* the input to the block. This means that each sublayer (attention or MLP) actually only has a fairly small effect on the output of the block, since most of it comes from all the earlier layers. We call the output of a block (including skip connections) the **residual stream**.\n",
+ "\n",
+ "Everything communicated from earlier layers to later layers must go via the residual stream, so it acts as a \"bottleneck\" in the transformer, essentially capturing everything the model has \"thought\" so far. This means it is often a natural thing to study, since it will contain everything important going on in the model.\n",
+ " \n"
+ ]
+ },
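+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The skip connections described above can be written as a simplified update rule. This is only a sketch: the symbols $x_l$, $\\mathrm{Attn}_l$ and $\\mathrm{MLP}_l$ are notation introduced here for illustration, and normalisation layers are omitted.\n",
+ "\n",
+ "$$\n",
+ "x_l^{\\text{mid}} = x_l + \\mathrm{Attn}_l(x_l), \\qquad x_{l+1} = x_l^{\\text{mid}} + \\mathrm{MLP}_l(x_l^{\\text{mid}})\n",
+ "$$\n",
+ "\n",
+ "Here $x_l$ is the residual stream entering block $l$ and $x_{l+1}$ is the residual stream leaving it. Each sublayer only adds its output to the running sum, which is why the stream at a given layer carries the contributions of all earlier layers."
+ ]
+ },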
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 49,
+ "referenced_widgets": [
+ "aa7a7c4e96fe4a0fa2e72a2579c37799",
+ "9ce9fa9d715a4dc0a0b5fa2778ac04e3",
+ "2129c74c13df48c894eb3b7b6e4f3f8c",
+ "12d40111a6a34ebc8460956895f1ac20",
+ "273fd4bfca3444d3a8b19d9ee3e96db1",
+ "1a5c2570dea344479476c636954cf2f9",
+ "199409203a804e81865a67b881c820c4",
+ "3df1fe78dd564517ad38bc6e463cb7be",
+ "deb7b5bcbdf44f0dad7a2a648b48a021",
+ "9bc15b80984946d4b75e13040ece4901",
+ "cdfd34c475c44e6babf1ca6ef9ce70ed"
+ ]
+ },
+ "id": "BP2Ju5AnNIzS",
+ "outputId": "ba632780-874b-408b-e306-d9eda436fd35"
+ },
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import hf_hub_download\n",
+ "\n",
+ "path_to_params = hf_hub_download(\n",
+ " repo_id=sae_repo_id,\n",
+ " filename=filename,\n",
+ " force_download=False,\n",
+ ")\n",
+ "\n",
+ "params = np.load(path_to_params)\n",
+ "pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8wy7DSTaRc90"
+ },
+ "source": [
+ "### Implementing the SAE\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "18HRoRagoPWP"
+ },
+ "source": [
+ "We now define the forward pass of the SAE for pedagogical purposes (in practice, we recommend using the implementation in SAELens).\n",
+ "\n",
+ "Gemma Scope is a collection of [JumpReLU SAEs](https://arxiv.org/abs/2407.14435). Each one is like a standard two-layer (one hidden layer) neural network, but where the activation function is a **JumpReLU**: a ReLU with a discontinuous jump."
+ ]
+ },
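+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Written out, a JumpReLU SAE can be sketched as follows, where $x$ is the input activation and $\\theta$ is a learned per-feature threshold. The notation is introduced here for illustration and mirrors the parameter names used in the code below (`W_enc`, `b_enc`, `threshold`, `W_dec`, `b_dec`).\n",
+ "\n",
+ "$$\n",
+ "z = x W_{\\text{enc}} + b_{\\text{enc}}, \\qquad a = \\mathbb{1}[z > \\theta] \\odot \\mathrm{ReLU}(z), \\qquad \\hat{x} = a W_{\\text{dec}} + b_{\\text{dec}}\n",
+ "$$\n",
+ "\n",
+ "A pre-activation that does not exceed its threshold is zeroed, while a positive pre-activation that exceeds it passes through unchanged, which gives the discontinuous jump at $\\theta$."
+ ]
+ },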
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "id": "WYfvS97fAFzq"
+ },
+ "outputs": [],
+ "source": [
+ "import torch.nn as nn\n",
+ "\n",
+ "class JumpReLUSAE(nn.Module):\n",
+ "    def __init__(self, d_model, d_sae):\n",
+ "        # Note that we initialise these to zeros because we're loading in pre-trained weights.\n",
+ "        # If you want to train your own SAEs, use a dedicated training library instead.\n",
+ "        super().__init__()\n",
+ "        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))\n",
+ "        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))\n",
+ "        self.threshold = nn.Parameter(torch.zeros(d_sae))\n",
+ "        self.b_enc = nn.Parameter(torch.zeros(d_sae))\n",
+ "        self.b_dec = nn.Parameter(torch.zeros(d_model))\n",
+ "\n",
+ "    def encode(self, input_acts):\n",
+ "        pre_acts = input_acts @ self.W_enc + self.b_enc\n",
+ "        mask = (pre_acts > self.threshold)\n",
+ "        acts = mask * torch.nn.functional.relu(pre_acts)\n",
+ "        return acts\n",
+ "\n",
+ "    def decode(self, acts):\n",
+ "        return acts @ self.W_dec + self.b_dec\n",
+ "\n",
+ "    def forward(self, acts):\n",
+ "        acts = self.encode(acts)\n",
+ "        recon = self.decode(acts)\n",
+ "        return recon\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "X91UGkU1cSrC",
+ "outputId": "1de284c2-32d2-434d-8f2c-a57beb23e007"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "JumpReLUSAE()"
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])\n",
+ "sae.load_state_dict(pt_params)\n",
+ "sae.cuda()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "spZhppkzjIAf"
+ },
+ "source": [
+ "### Running the SAE on model activations\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NrG3P-UNWSNp"
+ },
+ "source": [
+ "Let's first extract some activations from the model at the SAE target site. We'll demonstrate how to do this 'manually' first, by using PyTorch hooks. Note that this is not particularly good practice, and it's probably more practical to use a library like TransformerLens to handle hooking the SAE into a model forward pass. But for illustrative purposes, it's useful to see how it's done.\n",
+ "\n",
+ "We can gather activations at a site by registering a hook. To keep this local, we can wrap this in a function that registers a hook, runs the model while saving the intermediate activation, then removes the hook. (This is basically what TransformerLens is doing under the hood.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "id": "aSvKs581WU7j"
+ },
+ "outputs": [],
+ "source": [
+ "def gather_residual_activations(model, target_layer, inputs):\n",
+ "    target_act = None\n",
+ "    def gather_target_act_hook(mod, inputs, outputs):\n",
+ "        nonlocal target_act # make sure we can modify the target_act from the outer scope\n",
+ "        target_act = outputs[0]\n",
+ "        return outputs\n",
+ "    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)\n",
+ "    _ = model.forward(inputs)\n",
+ "    handle.remove()\n",
+ "    return target_act"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "dataset_name = \"cornell-movie-review-data/rotten_tomatoes/\"\n",
+ "\n",
+ "splits = {'train': 'train.parquet', 'validation': 'validation.parquet', 'test': 'test.parquet'}\n",
+ "df = pd.read_parquet(f\"hf://datasets/{dataset_name}\" + splits[\"train\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n = len(df)\n",
+ "\n",
+ "sub_df = df.sample(n=n)\n",
+ "\n",
+ "prompts = sub_df[\"text\"].tolist()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'cornell-movie-review-data_rotten_tomatoes__google_gemma-2-9b-it_layer_31_width_16k_average_l0_76_params.npz'"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import os\n",
+ "weight_name = dataset_name + \"/\" + model_name + \"/\" + filename\n",
+ "weight_name = weight_name.replace(os.sep, \"_\")\n",
+ "weight_name"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 8530/8530 [11:18<00:00, 12.57it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "target_acts = []\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    for prompt in tqdm(prompts):\n",
+ "        inputs = tokenizer.encode(prompt, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n",
+ "\n",
+ "        target_act = gather_residual_activations(model, layer, inputs)\n",
+ "        target_acts.append(target_act)\n",
+ "\n",
+ "        # Optionally, clear CUDA cache\n",
+ "        torch.cuda.empty_cache()\n",
+ "\n",
+ "\n",
+ "# Create a list of tensors\n",
+ "tensor_list = target_acts\n",
+ "\n",
+ "# Convert to NumPy and save\n",
+ "# np.savez(f'{weight_name}.npz', \n",
+ "# *[f'array_{i}' for i in range(len(tensor_list))],\n",
+ "# **{f'array_{i}': tensor.cpu().numpy() for i, tensor in enumerate(tensor_list)})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iS4Re5VTQti5"
+ },
+ "source": [
+ "Now, we can run our SAE on the saved activations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 8530/8530 [00:05<00:00, 1451.02it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "sae_acts = []\n",
+ " \n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    for target_act in tqdm(target_acts):\n",
+ "        # Move the input to GPU if it's not already there\n",
+ "        target_act_gpu = target_act.to(torch.float32).cuda()\n",
+ "\n",
+ "        sae_act = sae.encode(target_act_gpu)\n",
+ "\n",
+ "        # Flag each feature that fires on at least one token, then move to CPU as numpy\n",
+ "        sae_act_aggregated = ((sae_act[:,:,:] > 0).sum(1) > 0).cpu().numpy()\n",
+ "\n",
+ "        # Append the CPU numpy array\n",
+ "        sae_acts.append(sae_act_aggregated)\n",
+ "\n",
+ "        # Optionally, clear CUDA cache\n",
+ "        torch.cuda.empty_cache()"
+ ]
+ },
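+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The aggregation in the loop above reduces the per-token activations to one binary indicator per feature and prompt. As a sketch, with $a_{t,j}$ denoting the activation of feature $j$ at token position $t$ and $f_j$ the resulting indicator (notation introduced here for illustration):\n",
+ "\n",
+ "$$\n",
+ "f_j = \\mathbb{1}\\left[\\sum_t \\mathbb{1}[a_{t,j} > 0] > 0\\right]\n",
+ "$$\n",
+ "\n",
+ "So $f_j = 1$ exactly when feature $j$ fires on at least one token of the prompt. As written, the sum runs over every position, including any special tokens that the tokenizer prepends when `add_special_tokens=True`."
+ ]
+ },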
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "59kRU1_Iim3k"
+ },
+ "source": [
+ "Let's just double check that the model looks sensible by checking that we explain a decent chunk of the variance:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Concatenate the list of numpy arrays on the first dimension\n",
+ "array = np.concatenate(sae_acts, axis=0).astype(float)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " 0 1 2 3 4 5 6 7 8 9 ... 16375 16376 \\\n",
+ "0 1.0 0.0 1.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "1 1.0 0.0 1.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "2 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "3 1.0 0.0 1.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "4 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "... ... ... ... ... ... ... ... ... ... ... ... ... ... \n",
+ "8525 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "8526 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "8527 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "8528 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "8529 1.0 0.0 1.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n",
+ "\n",
+ " 16377 16378 16379 16380 16381 16382 16383 label \n",
+ "0 0.0 1.0 0.0 1.0 0.0 0.0 1.0 1 \n",
+ "1 0.0 1.0 0.0 1.0 0.0 1.0 1.0 0 \n",
+ "2 0.0 1.0 0.0 1.0 0.0 0.0 1.0 0 \n",
+ "3 0.0 1.0 0.0 1.0 0.0 1.0 1.0 0 \n",
+ "4 1.0 0.0 0.0 1.0 0.0 0.0 1.0 0 \n",
+ "... ... ... ... ... ... ... ... ... \n",
+ "8525 0.0 1.0 0.0 1.0 0.0 0.0 1.0 1 \n",
+ "8526 0.0 0.0 0.0 1.0 0.0 0.0 1.0 1 \n",
+ "8527 0.0 1.0 0.0 1.0 0.0 0.0 1.0 1 \n",
+ "8528 1.0 1.0 0.0 1.0 0.0 0.0 1.0 0 \n",
+ "8529 0.0 1.0 0.0 1.0 0.0 0.0 1.0 0 \n",
+ "\n",
+ "[8530 rows x 16385 columns]"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "result_df = pd.DataFrame(array)\n",
+ "result_df[\"label\"] = sub_df[\"label\"].values\n",
+ "\n",
+ "result_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Accuracy on training: 0.8402637845759297\n",
+ "Classification Report on training:\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.83 0.86 0.84 2713\n",
+ " 1 0.86 0.82 0.84 2746\n",
+ "\n",
+ " accuracy 0.84 5459\n",
+ " macro avg 0.84 0.84 0.84 5459\n",
+ "weighted avg 0.84 0.84 0.84 5459\n",
+ "\n",
+ "Accuracy on validation: 0.8234432234432234\n",
+ "\n",
+ "Classification Report on validation:\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.82 0.84 0.83 692\n",
+ " 1 0.83 0.81 0.82 673\n",
+ "\n",
+ " accuracy 0.82 1365\n",
+ " macro avg 0.82 0.82 0.82 1365\n",
+ "weighted avg 0.82 0.82 0.82 1365\n",
+ "\n",
+ "Non-zero features: [6272, 8410, 11367, 14557, 15837, 12526, 7886, 1518, 13556, 854, 14929, 7796, 15291, 1244, 2442, 14484, 10718, 13507, 264, 8867, 13444, 13545, 6532, 5864]\n",
+ "\n",
+ "Top 20 Most Important Features:\n",
+ " feature importance\n",
+ "6272 6272 0.587859\n",
+ "8410 8410 0.123248\n",
+ "11367 11367 0.092920\n",
+ "14557 14557 0.053496\n",
+ "15837 15837 0.022849\n",
+ "12526 12526 0.018051\n",
+ "7886 7886 0.012444\n",
+ "1518 1518 0.011040\n",
+ "13556 13556 0.010179\n",
+ "854 854 0.009852\n",
+ "14929 14929 0.007890\n",
+ "7796 7796 0.006973\n",
+ "15291 15291 0.005895\n",
+ "1244 1244 0.005341\n",
+ "2442 2442 0.004361\n",
+ "14484 14484 0.004108\n",
+ "10718 10718 0.004002\n",
+ "13507 13507 0.003902\n",
+ "264 264 0.003747\n",
+ "8867 8867 0.003463\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "from sklearn.tree import DecisionTreeClassifier, plot_tree\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.metrics import accuracy_score, classification_report\n",
+ "import matplotlib.pyplot as plt\n",
+ "import requests\n",
+ "\n",
+ "max_depth = 5\n",
+ "\n",
+ "def get_feature_descriptions(feature, model=\"gemma-2-2b\", layer=\"20-gemmascope-res-65k\"):\n",
+ "    url = f\"https://www.neuronpedia.org/api/feature/{model}/{layer}/{feature}\"\n",
+ "    response = requests.get(url)\n",
+ "    output = response.json()[\"explanations\"][0][\"description\"]\n",
+ "    return output\n",
+ "\n",
+ "get_feature_descriptions_gemma_2_9b = lambda x: get_feature_descriptions(x, model=\"gemma-2-9b-it\", layer=\"31-gemmascope-res-16k\")\n",
+ "\n",
+ "# Assuming your data is already in a DataFrame called 'result_df'\n",
+ "# If not, load your data into a DataFrame first\n",
+ "\n",
+ "# Separate features and target\n",
+ "X = result_df.drop('label', axis=1)\n",
+ "y = result_df['label']\n",
+ "\n",
+ "# Split the data into training and testing sets\n",
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
+ "\n",
+ "# Split the data into training and validation sets\n",
+ "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)\n",
+ "\n",
+ "# Fit decision tree classifier with constraints\n",
+ "clf = DecisionTreeClassifier(\n",
+ " max_depth=max_depth, # Limit the depth of the tree\n",
+ " random_state=42\n",
+ ")\n",
+ "clf.fit(X_train, y_train)\n",
+ "\n",
+ "# Make predictions\n",
+ "y_train_pred = clf.predict(X_train)\n",
+ "y_val_pred = clf.predict(X_val)\n",
+ "\n",
+ "print(\"Accuracy on training:\", accuracy_score(y_train, y_train_pred))\n",
+ "print(\"Classification Report on training:\")\n",
+ "print(classification_report(y_train, y_train_pred))\n",
+ "\n",
+ "print(\"Accuracy on validation:\", accuracy_score(y_val, y_val_pred))\n",
+ "print(\"\\nClassification Report on validation:\")\n",
+ "print(classification_report(y_val, y_val_pred))\n",
+ "\n",
+ "# Get feature importances\n",
+ "feature_importance = pd.DataFrame({\n",
+ " 'feature': X.columns,\n",
+ " 'importance': clf.feature_importances_\n",
+ "})\n",
+ "\n",
+ "# Sort features by importance\n",
+ "feature_importance = feature_importance.sort_values('importance', ascending=False)\n",
+ "\n",
+ "print(\"Non-zero features:\", feature_importance.loc[feature_importance[\"importance\"] > 0].feature.tolist())\n",
+ "\n",
+ "# Print top 20 most important features\n",
+ "print(\"\\nTop 20 Most Important Features:\")\n",
+ "print(feature_importance.head(20))\n",
+ "\n",
+ "# Get feature descriptions for non-zero importance features\n",
+ "non_zero_features = feature_importance.loc[feature_importance[\"importance\"] > 0, \"feature\"].tolist()\n",
+ "feature_descriptions = {feature: get_feature_descriptions_gemma_2_9b(feature) for feature in non_zero_features}\n",
+ "\n",
+ "# Create a mapping of feature names to their descriptions\n",
+ "feature_names_with_desc = [f\"{feat}\\n{feature_descriptions[feat][:50]}...\" if feat in feature_descriptions else feat for feat in X.columns]\n",
+ "\n",
+ "# # Visualize the decision tree with feature descriptions\n",
+ "# plt.figure(figsize=(30,15))\n",
+ "# plot_tree(clf, feature_names=feature_names_with_desc, class_names=clf.classes_.astype(str), filled=True, rounded=True, max_depth=3)\n",
+ "# plt.savefig('constrained_decision_tree_with_descriptions.png', dpi=300, bbox_inches='tight')\n",
+ "# plt.close()\n",
+ "\n",
+ "# print(\"Constrained decision tree visualization with feature descriptions has been saved as 'constrained_decision_tree_with_descriptions.png'\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Decision Tree model has been exported to decision_tree_max_depth_5_ google_gemma-2-9b-it_layer_31_width_16k_average_l0_76_params.pkl\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pickle\n",
+ "\n",
+ "clf_name = f\"decision_tree_max_depth_{max_depth}_ \"+ model_name + \"_\" + filename.split(\".npz\")[0]\n",
+ "clf_name = clf_name.replace(os.sep, \"_\")\n",
+ "\n",
+ "with open(f'{clf_name}.pkl', 'wb') as model_file:\n",
+ "    pickle.dump(clf, model_file)\n",
+ "\n",
+ "print(f\"Decision Tree model has been exported to {clf_name}.pkl\")\n",
+ "\n",
+ "with open(f\"{clf_name}.pkl\", 'rb') as model_file:\n",
+ "    clf = pickle.load(model_file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(f\"{clf_name}.pkl\", 'rb') as model_file:\n",
+ "    clf = pickle.load(model_file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Accuracy on training: 0.8521707272394211\n",
+ "Classification Report on training:\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.84 0.87 0.85 2713\n",
+ " 1 0.87 0.83 0.85 2746\n",
+ "\n",
+ " accuracy 0.85 5459\n",
+ " macro avg 0.85 0.85 0.85 5459\n",
+ "weighted avg 0.85 0.85 0.85 5459\n",
+ "\n",
+ "Accuracy on validation: 0.8652014652014652\n",
+ "\n",
+ "Classification Report on validation:\n",
+ " precision recall f1-score support\n",
+ "\n",
+ " 0 0.85 0.89 0.87 692\n",
+ " 1 0.88 0.84 0.86 673\n",
+ "\n",
+ " accuracy 0.87 1365\n",
+ " macro avg 0.87 0.86 0.87 1365\n",
+ "weighted avg 0.87 0.87 0.87 1365\n",
+ "\n",
+ "Non zero features: [6272, 8410, 14557, 7886, 11367, 13556, 15837, 6634, 4795, 1518, 3456, 7796, 3404, 15142, 4364, 12526, 3628, 920, 12970, 5236, 1631, 1374, 13679, 14218, 10816, 3762]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.linear_model import LogisticRegression\n",
+ "X = result_df.drop('label', axis=1)\n",
+ "y = result_df['label']\n",
+ "\n",
+ "# Split the data into training and testing sets\n",
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
+ "\n",
+ "# Split the data into training and validation sets\n",
+ "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)\n",
+ "\n",
+ "# # Fit logistic regression with L1 regularization\n",
+ "# clf = LogisticRegression(penalty='l1', solver='liblinear', C=0.1, random_state=42)\n",
+ "# clf.fit(X_train, y_train)\n",
+ "\n",
+ "# # Make predictions\n",
+ "# y_pred = clf.predict(X_test)\n",
+ "\n",
+ "C = 0.01\n",
+ "\n",
+ "# scaler = StandardScaler()\n",
+ "# X_train_scaled = scaler.fit_transform(X_train)\n",
+ "# X_val_scaled = scaler.transform(X_val)\n",
+ "\n",
+ "# Fit logistic regression with L1 regularization\n",
+ "clf = LogisticRegression(penalty='l1', solver='liblinear', C=C, random_state=42)\n",
+ "clf.fit(X_train, y_train)\n",
+ "\n",
+ "# Make predictions\n",
+ "y_val_pred = clf.predict(X_val)\n",
+ "\n",
+ "print(\"Accuracy on training:\", accuracy_score(y_train, clf.predict(X_train)))\n",
+ "print(\"Classification Report on training:\")\n",
+ "print(classification_report(y_train, clf.predict(X_train)))\n",
+ "\n",
+ "# Print accuracy and classification report\n",
+ "print(\"Accuracy on validation:\", accuracy_score(y_val, y_val_pred))\n",
+ "print(\"\\nClassification Report on validation:\")\n",
+ "print(classification_report(y_val, y_val_pred))\n",
+ "\n",
+ "# Get feature importances\n",
+ "feature_importance = pd.DataFrame({\n",
+ " 'feature': X.columns,\n",
+ " 'importance': np.abs(clf.coef_[0])\n",
+ "})\n",
+ "\n",
+ "# Sort features by importance\n",
+ "feature_importance = feature_importance.sort_values('importance', ascending=False)\n",
+ "\n",
+ "print(\"Non zero features:\", feature_importance.loc[feature_importance[\"importance\"] > 0].feature.tolist())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MDHhbeMDi4cv"
+ },
+ "source": [
+ "It's always worth running this sort of sanity check when you do this by hand, to confirm that you haven't hooked the wrong site or dropped a scaling factor. Here, though, our results all look the way they are supposed to.\n",
+ "\n",
+ "Note that there's a bit of a gotcha here: our SAEs are *NOT* trained on the BOS token, because we found that it tended to be a large outlier and to mess up training. So they tend to give nonsense when we apply them to it, and we need to be careful not to do this accidentally! We can see this above: the BOS token is a total outlier in terms of L0!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iphXauyzlBUS"
+ },
+ "source": [
+ "Let's look at the highest activating features on this input text, on each token position:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "6272"
+ ]
+ },
+ "execution_count": 49,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "feature"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[('', 0.0), ('I', 0.0), (' really', 0.0), (' wished', 0.0), (' I', 0.0), (' could', 0.0), (' give', 0.0), (' this', 0.0), (' movie', 0.29309767), (' a', 0.0), (' higher', 0.0), (' rating', 0.0), ('.', 0.0), (' The', 0.21383312), (' plot', 0.0), (' was', 0.22131765), (' interesting', 0.0), (',', 0.0), (' but', 0.51617336), (' the', 0.5799874), (' acting', 0.347309), (' was', 0.36400035), (' terrible', 0.49232012), ('.', 0.7318199), (' The', 0.56170917), (' special', 0.0), (' effects', 0.45976144), (' were', 0.99999994), (' great', 0.0), (',', 0.0), (' but', 0.47706267), (' the', 0.4011524), (' pacing', 0.7848547), (' was', 0.9232518), (' off', 0.4621812), ('.', 0.0), (' The', 0.0), (' movie', 0.59335506), (' was', 0.57606274), (' too', 0.0), (' long', 0.0), (',', 0.0), (' but', 0.4116312), (' the', 0.0), (' ending', 0.5189625), (' was', 0.71944976), (' satisfying', 0.0), ('.', 0.0)]\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "orientation": "h",
+ "type": "bar",
+ "x": [
+ -1.2354252922133167,
+ 0.7079076191442241,
+ 0.5767604549015611
+ ],
+ "y": [
+ "phrases indicating inadequate conditions or situations",
+ " expressions of positive recommendations and personal endorsements",
+ "phrases that convey encouragement and recognition of achievements"
+ ]
+ }
+ ],
+ "layout": {
+ "height": 500,
+ "margin": {
+ "l": 200
+ },
+ "title": {
+ "text": "Feature contribution"
+ },
+ "xaxis": {
+ "title": {
+ "text": "Contribution"
+ }
+ },
+ "yaxis": {
+ "autorange": "reversed",
+ "title": {
+ "text": "Features"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import gradio as gr\n",
+ "\n",
+ "topk = 3\n",
+ "\n",
+ "examples = [\n",
+ "    \"a masterpiece four years in the making .\",\n",
+ "    \"a sentimental mess that never rings true .\",\n",
+ "    \"the action clichés just pile up .\"\n",
+ "]\n",
+ "\n",
+ "text = \"I really wished I could give this movie a higher rating. The plot was interesting, but the acting was terrible. The special effects were great, but the pacing was off. The movie was too long, but the ending was satisfying.\"\n",
+ "\n",
+ "inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n",
+ "\n",
+ "target_act = gather_residual_activations(model, layer, inputs)\n",
+ "sae_act = sae.encode(target_act)\n",
+ "# Binarize per feature: True if the feature fires on at least one token of the text\n",
+ "sae_act_aggregated = (sae_act > 0).any(dim=1).cpu().numpy()\n",
+ "\n",
+ "X = pd.DataFrame(sae_act_aggregated)\n",
+ "\n",
+ "feature_contributions = X.iloc[0].astype(float).values * clf.coef_[0]\n",
+ "\n",
+ "contrib_df = pd.DataFrame({\n",
+ "    'feature': range(len(feature_contributions)),\n",
+ "    'contribution': feature_contributions\n",
+ "})\n",
+ "\n",
+ "contrib_df = contrib_df.loc[contrib_df['contribution'].abs() > 0]\n",
+ "\n",
+ "# Sort by absolute contribution and get top N\n",
+ "contrib_df = contrib_df.reindex(contrib_df['contribution'].abs().sort_values(ascending=False).index)\n",
+ "\n",
+ "contrib_df = contrib_df.head(topk)\n",
+ "contrib_df[\"description\"] = contrib_df[\"feature\"].apply(get_feature_descriptions)\n",
+ "\n",
+ "import plotly.graph_objs as go\n",
+ "\n",
+ "fig = go.Figure(go.Bar(\n",
+ "    x=contrib_df['contribution'],\n",
+ "    y=contrib_df['description'],\n",
+ "    orientation='h'  # Horizontal bar chart\n",
+ "))\n",
+ "\n",
+ "fig.update_layout(\n",
+ "    title='Feature contribution',\n",
+ "    xaxis_title='Contribution',\n",
+ "    yaxis_title='Features',\n",
+ "    height=500,\n",
+ "    margin=dict(l=200)  # Increase left margin to accommodate longer feature names\n",
+ ")\n",
+ "fig.update_yaxes(autorange=\"reversed\")\n",
+ "\n",
+ "probability = clf.predict_proba(X)[0]\n",
+ "classes = {\n",
+ "    \"Positive\": probability[1],\n",
+ "    \"Negative\": probability[0]\n",
+ "}\n",
+ "\n",
+ "# (description, feature index) pairs for the dropdown\n",
+ "choices = list(zip(contrib_df[\"description\"], contrib_df[\"feature\"]))\n",
+ "dropdown = gr.Dropdown(choices=choices,\n",
+ "                       value=choices[0][1],\n",
+ "                       interactive=True, label=\"Features\")\n",
+ "\n",
+ "feature = choices[0][1]\n",
+ "\n",
+ "# sae_act from above already holds the activations for this text, so reuse it\n",
+ "activated_tokens = sae_act[:, :, feature]\n",
+ "max_activation = activated_tokens.max().item()\n",
+ "# Normalize to [0, 1]; skip the division when the feature never fires\n",
+ "if max_activation > 0:\n",
+ "    activated_tokens = activated_tokens / max_activation\n",
+ "\n",
+ "activated_tokens = activated_tokens.cpu().detach().numpy()\n",
+ "\n",
+ "output = []\n",
+ "\n",
+ "# Pair each decoded token with its normalized activation\n",
+ "for i, token_id in enumerate(inputs[0, :]):\n",
+ "    token = tokenizer.decode(token_id)\n",
+ "    output.append((token, activated_tokens[0, i]))\n",
+ "\n",
+ "print(output)\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'Positive': 0.3629834319308022, 'Negative': 0.6370165680691978}"
+ ]
+ },
+ "execution_count": 48,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "classes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[('', 0.0), ('the', 0.0), (' action', 0.0), (' clichés', 0.47497016), (' just', 1.0), (' pile', 0.516835), (' up', 0.46400496), (' .', 0.4915409)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "feature = choices[2][1]\n",
+ "\n",
+ "inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n",
+ "\n",
+ "target_act = gather_residual_activations(model, layer, inputs)\n",
+ "sae_act = sae.encode(target_act)\n",
+ "\n",
+ "activated_tokens = sae_act[:, :, feature]\n",
+ "max_activation = activated_tokens.max().item()\n",
+ "# Normalize to [0, 1]; skip the division when the feature never fires\n",
+ "if max_activation > 0:\n",
+ "    activated_tokens = activated_tokens / max_activation\n",
+ "\n",
+ "activated_tokens = activated_tokens.cpu().detach().numpy()\n",
+ "\n",
+ "output = []\n",
+ "\n",
+ "for i, token_id in enumerate(inputs[0, :]):\n",
+ "    token = tokenizer.decode(token_id)\n",
+ "    output.append((token, activated_tokens[0, i]))\n",
+ "\n",
+ "print(output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Non zero features: [6272, 8410, 14557, 7886, 11367, 13556, 15837, 6634, 4795, 1518, 3456, 7796, 3404, 15142, 4364, 12526, 3628, 920, 12970, 5236, 1631, 1374, 13679, 14218, 10816, 3762]\n",
+ "\n",
+ "Top Important Features:\n",
+ " feature importance\n",
+ "6272 6272 1.235425\n",
+ "8410 8410 0.707908\n",
+ "14557 14557 0.576760\n",
+ "7886 7886 0.485816\n",
+ "11367 11367 0.467120\n",
+ "13556 13556 0.417031\n",
+ "15837 15837 0.383319\n",
+ "6634 6634 0.354729\n",
+ "4795 4795 0.327832\n",
+ "1518 1518 0.325042\n",
+ "3456 3456 0.193763\n",
+ "7796 7796 0.178672\n",
+ "3404 3404 0.155527\n",
+ "15142 15142 0.123701\n",
+ "4364 4364 0.114390\n",
+ "12526 12526 0.098219\n",
+ "3628 3628 0.084569\n",
+ "920 920 0.056221\n",
+ "12970 12970 0.046524\n",
+ "5236 5236 0.046149\n"
+ ]
+ }
+ ],
+ "source": [
+ "import requests\n",
+ "\n",
+ "def get_feature_descriptions(feature):\n",
+ "    \"\"\"Fetch the first Neuronpedia explanation for one SAE feature index.\"\"\"\n",
+ "    layer_name = f\"{layer}-gemmascope-res-{width}\"\n",
+ "    model_name_neuronpedia = model_name.split(\"/\")[1]\n",
+ "\n",
+ "    url = f\"https://www.neuronpedia.org/api/feature/{model_name_neuronpedia}/{layer_name}/{feature}\"\n",
+ "\n",
+ "    response = requests.get(url, timeout=30)\n",
+ "    response.raise_for_status()  # Surface HTTP errors instead of a confusing KeyError below\n",
+ "    output = response.json()[\"explanations\"][0][\"description\"]\n",
+ "    return output\n",
+ "\n",
+ "# Use the absolute value of each coefficient as that feature's importance\n",
+ "feature_importance = pd.DataFrame({\n",
+ "    'feature': X.columns,\n",
+ "    'importance': np.abs(clf.coef_[0])\n",
+ "})\n",
+ "\n",
+ "# Sort features by importance\n",
+ "feature_importance = feature_importance.sort_values('importance', ascending=False)\n",
+ "feature_importance = feature_importance.loc[feature_importance[\"importance\"] > 0]\n",
+ "\n",
+ "# feature_importance[\"description\"] = feature_importance[\"feature\"].apply(get_feature_descriptions)\n",
+ "\n",
+ "print(\"Non zero features:\", feature_importance.feature.tolist())\n",
+ "\n",
+ "# Print top 20 most important features\n",
+ "print(\"\\nTop Important Features:\")\n",
+ "print(feature_importance.head(20))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Linear classifier model has been exported to linear_classifier_C_0.01_ google_gemma-2-9b-it_layer_31_width_16k_average_l0_76_params.pkl\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import pickle\n",
+ "\n",
+ "clf_name = f\"linear_classifier_C_{C}_ \"+ model_name + \"_\" + filename.split(\".npz\")[0]\n",
+ "clf_name = clf_name.replace(os.sep, \"_\")\n",
+ "\n",
+ "with open(f'{clf_name}.pkl', 'wb') as model_file:\n",
+ " pickle.dump(clf, model_file)\n",
+ "\n",
+ "print(f\"Linear classifier model has been exported to {clf_name}.pkl\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "# params = {\n",
+ "# \"model_name\" : \"google/gemma-2-2b\",\n",
+ "# \"width\" : \"16k\",\n",
+ "# \"layer\" : 23,\n",
+ "# \"l0\" : 74,\n",
+ "# \"sae_repo_id\": \"google/gemma-scope-2b-pt-res\",\n",
+ "# \"filename\" : \"layer_23/width_16k/average_l0_74/params.npz\"\n",
+ "# }\n",
+ "\n",
+ "params = {\n",
+ "    \"model_name\": \"google/gemma-2-9b-it\",\n",
+ "    \"width\": \"16k\",\n",
+ "    \"layer\": 31,\n",
+ "    \"l0\": 76,\n",
+ "    \"sae_repo_id\": \"google/gemma-scope-9b-it-res\",\n",
+ "    \"filename\": \"layer_31/width_16k/average_l0_76/params.npz\"\n",
+ "}\n",
+ "\n",
+ "model_name = params[\"model_name\"]\n",
+ "width = params[\"width\"]\n",
+ "layer = params[\"layer\"]\n",
+ "l0 = params[\"l0\"]\n",
+ "sae_repo_id = params[\"sae_repo_id\"]\n",
+ "filename = params[\"filename\"]\n",
+ "\n",
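+ "feature_importance = pd.read_csv(\"feature_importance.csv\")\n",
+ "# Keep the first three rows; copy so the column assignment below does not write to a slice\n",
+ "feature_importance = feature_importance.iloc[:3].copy()\n",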
+ "\n",
+ "import requests\n",
+ "\n",
+ "def get_feature_descriptions(feature):\n",
+ "    \"\"\"Fetch the first Neuronpedia explanation for one SAE feature index.\"\"\"\n",
+ "    layer_name = f\"{layer}-gemmascope-res-{width}\"\n",
+ "    model_name_neuronpedia = model_name.split(\"/\")[1]\n",
+ "\n",
+ "    url = f\"https://www.neuronpedia.org/api/feature/{model_name_neuronpedia}/{layer_name}/{feature}\"\n",
+ "\n",
+ "    response = requests.get(url, timeout=30)\n",
+ "    response.raise_for_status()  # Surface HTTP errors instead of a confusing KeyError below\n",
+ "    output = response.json()[\"explanations\"][0][\"description\"]\n",
+ "    return output\n",
+ "\n",
+ "feature_importance[\"description\"] = feature_importance[\"feature\"].apply(get_feature_descriptions)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "orientation": "h",
+ "type": "bar",
+ "x": [
+ 0.7149210223756529,
+ 0.5306234489651611,
+ 0.3787273657087757
+ ],
+ "y": [
+ "URLs and hyperlinks within the text",
+ " numerical values and statistical data representations",
+ "keywords and identifiers related to programming and networking concepts"
+ ]
+ }
+ ],
+ "layout": {
+ "height": 500,
+ "margin": {
+ "l": 200
+ },
+ "title": {
+ "text": "Feature Importance"
+ },
+ "xaxis": {
+ "title": {
+ "text": "Importance"
+ }
+ },
+ "yaxis": {
+ "autorange": "reversed",
+ "title": {
+ "text": "Features"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import plotly.graph_objs as go\n",
+ "\n",
+ "fig = go.Figure(go.Bar(\n",
+ " x=feature_importance['importance'],\n",
+ " y=feature_importance['description'],\n",
+ " orientation='h' # Horizontal bar chart\n",
+ "))\n",
+ "\n",
+ "fig.update_layout(\n",
+ " title='Feature Importance',\n",
+ " xaxis_title='Importance',\n",
+ " yaxis_title='Features',\n",
+ " height=500,\n",
+ " margin=dict(l=200) # Increase left margin to accommodate longer feature names\n",
+ ")\n",
+ "fig.update_yaxes(autorange=\"reversed\")\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[ 3946 4438 13920]\n",
+ "Feature: 3946\n",
+ "Coefficient: 0.0\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "Feature: 4438\n",
+ "Coefficient: -1.4645945804608147\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "Feature: 13920\n",
+ "Coefficient: 0.763696937782067\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "topk = 3\n",
+ "topk_features = feature_importance.head(topk).feature.values\n",
+ "\n",
+ "print(topk_features)\n",
+ "\n",
+ "from IPython.display import IFrame\n",
+ "html_template = \"https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300\"\n",
+ "\n",
+ "def get_dashboard_html(sae_release, sae_id, feature_idx=0):\n",
+ " return html_template.format(sae_release, sae_id, feature_idx)\n",
+ "\n",
+ "for feature_idx in topk_features:\n",
+ " print(f\"Feature: {feature_idx}\")\n",
+ " print(f\"Coefficient: {clf.coef_[0][feature_idx]}\")\n",
+ " html = get_dashboard_html(sae_release = \"gemma-2-2b\", sae_id=\"23-gemmascope-res-16k\", feature_idx=feature_idx)\n",
+ " display(IFrame(html, width=1200, height=600))\n",
+ " print(\"\\n\")\n",
+ "\n",
+ "# html = get_dashboard_html(sae_release = \"gemma-2-2b\", sae_id=\"20-gemmascope-res-16k\", feature_idx=10004)\n",
+ "# IFrame(html, width=1200, height=600)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "27d7a9df72e842c58cae402f94fa60a7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "ename": "EOFError",
+ "evalue": "Ran out of input",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mEOFError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m/tmp/ipykernel_263180/1539694166.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{scaler_name}.pkl\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mscaler_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0mscaler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscaler_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mEOFError\u001b[0m: Ran out of input"
+ ]
+ }
+ ],
+ "source": [
+ "import gradio as gr\n",
+ "import os\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "from huggingface_hub import hf_hub_download\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "\n",
+ "torch.set_grad_enabled(False) # avoid blowing up mem\n",
+ "\n",
+ "params = {\n",
+ " \"model_name\" : \"google/gemma-2-2b\",\n",
+ " \"width\" : \"16k\",\n",
+ " \"layer\" : 23,\n",
+ " \"l0\" : 74,\n",
+ " \"sae_repo_id\": \"google/gemma-scope-2b-pt-res\",\n",
+ " \"filename\" : \"layer_23/width_16k/average_l0_74/params.npz\"\n",
+ "}\n",
+ "\n",
+ "model_name = params[\"model_name\"]\n",
+ "width = params[\"width\"]\n",
+ "layer = params[\"layer\"]\n",
+ "l0 = params[\"l0\"]\n",
+ "sae_repo_id = params[\"sae_repo_id\"]\n",
+ "filename = params[\"filename\"]\n",
+ "\n",
+ "C = 0.01\n",
+ "\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ " model_name,\n",
+ " device_map='auto',\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "\n",
+ "path_to_params = hf_hub_download(\n",
+ " repo_id=sae_repo_id,\n",
+ " filename=filename,\n",
+ " force_download=False,\n",
+ ")\n",
+ "\n",
+ "params = np.load(path_to_params)\n",
+ "pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}\n",
+ "\n",
+ "import pickle\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "\n",
+ "clf_name = f\"linear_classifier_C_{C}_ \"+ model_name + \"_\" + filename.split(\".npz\")[0]\n",
+ "clf_name = clf_name.replace(os.sep, \"_\")\n",
+ "\n",
+ "scaler_name = f\"scaler_C_{C}_ \"+ model_name + \"_\" + filename.split(\".npz\")[0]\n",
+ "scaler_name = scaler_name.replace(os.sep, \"_\")\n",
+ "\n",
+ "with open(f\"{clf_name}.pkl\", 'rb') as model_file:\n",
+ " clf = pickle.load(model_file)\n",
+ "\n",
+ "with open(f\"{scaler_name}.pkl\", 'rb') as scaler_file:\n",
+ " scaler = pickle.load(scaler_file)\n",
+ "\n",
+ "import torch.nn as nn\n",
+ "class JumpReLUSAE(nn.Module):\n",
+ "    def __init__(self, d_model, d_sae):\n",
+ "        # Note that we initialise these to zeros because we're loading in pre-trained weights.\n",
+ "        # If you want to train your own SAEs then we recommend using blah\n",
+ "        super().__init__()\n",
+ "        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))\n",
+ "        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))\n",
+ "        self.threshold = nn.Parameter(torch.zeros(d_sae))\n",
+ "        self.b_enc = nn.Parameter(torch.zeros(d_sae))\n",
+ "        self.b_dec = nn.Parameter(torch.zeros(d_model))\n",
+ "\n",
+ "    def encode(self, input_acts):\n",
+ "        pre_acts = input_acts @ self.W_enc + self.b_enc\n",
+ "        mask = (pre_acts > self.threshold)\n",
+ "        acts = mask * torch.nn.functional.relu(pre_acts)\n",
+ "        return acts\n",
+ "\n",
+ "    def decode(self, acts):\n",
+ "        return acts @ self.W_dec + self.b_dec\n",
+ "\n",
+ "    def forward(self, acts):\n",
+ "        acts = self.encode(acts)\n",
+ "        recon = self.decode(acts)\n",
+ "        return recon\n",
+ "\n",
+ "sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])\n",
+ "sae.load_state_dict(pt_params)\n",
+ "sae.cuda()\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def gather_residual_activations(model, target_layer, inputs):\n",
+ "    target_act = None\n",
+ "    def gather_target_act_hook(mod, inputs, outputs):\n",
+ "        nonlocal target_act # make sure we can modify the target_act from the outer scope\n",
+ "        target_act = outputs[0]\n",
+ "        return outputs\n",
+ "    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)\n",
+ "    _ = model.forward(inputs)\n",
+ "    handle.remove()\n",
+ "    return target_act\n",
+ "\n",
+ "import requests\n",
+ "\n",
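+ "def get_feature_descriptions(feature):\n",
+ "    layer_name = f\"{layer}-gemmascope-res-{width}\"\n",
+ "    model_name_neuronpedia = model_name.split(\"/\")[1]\n",
+ "\n",
+ "    url = f\"https://www.neuronpedia.org/api/feature/{model_name_neuronpedia}/{layer_name}/{feature}\"\n",
+ "\n",
+ "    response = requests.get(url, timeout=30)\n",
+ "    response.raise_for_status()\n",
+ "    explanations = response.json().get(\"explanations\") or []\n",
+ "    # Fall back to the feature index when no explanation is available for this feature.\n",
+ "    if not explanations:\n",
+ "        return f\"feature {feature}\"\n",
+ "    return explanations[0][\"description\"]\n",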
+ "\n",
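+ "def embed_content(url):\n",
+ "    # Assumption: the template wraps `url` in an iframe; the size mirrors the IFrame call used earlier.\n",
+ "    html_content = f\"\"\"\n",
+ "    <iframe src=\"{url}\" width=\"1200\" height=\"600\"></iframe>\n",
+ "    \"\"\"\n",
+ "    return html_content\n",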
+ "\n",
+ "def dummy_function(*args):\n",
+ " # This is a placeholder function. Replace with your actual logic.\n",
+ " return \"Scores will be displayed here\"\n",
+ "\n",
+ "examples = [\n",
+ " \"a masterpiece four years in the making .\",\n",
+ " \"a sentimental mess that never rings true .\",\n",
+ " \"the action clichés just pile up .\"\n",
+ "]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import gradio as gr\n",
+ "\n",
+ "topk = 5\n",
+ "\n",
+ "def get_features(text):\n",
+ "\n",
+ "    inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n",
+ "\n",
+ "    target_act = gather_residual_activations(model, layer, inputs)\n",
+ "    sae_act = sae.encode(target_act)\n",
+ "    # Binary indicator per feature: did it fire on any token after BOS?\n",
+ "    sae_act_aggregated = ((sae_act[:,1:,:] > 0).sum(1) > 0).cpu().numpy()\n",
+ "\n",
+ "    X = pd.DataFrame(sae_act_aggregated)\n",
+ "\n",
+ "    feature_contributions = X.iloc[0].astype(float).values * clf.coef_[0]\n",
+ "\n",
+ "    contrib_df = pd.DataFrame({\n",
+ "        'feature': range(len(feature_contributions)),\n",
+ "        'contribution': feature_contributions\n",
+ "    })\n",
+ "\n",
+ "    contrib_df = contrib_df.loc[contrib_df['contribution'].abs() > 0]\n",
+ "\n",
+ "    # Sort by absolute contribution and get top N\n",
+ "    contrib_df = contrib_df.reindex(contrib_df['contribution'].abs().sort_values(ascending=False).index)\n",
+ "\n",
+ "    contrib_df = contrib_df.head(topk).copy()\n",
+ "    contrib_df[\"description\"] = contrib_df[\"feature\"].apply(get_feature_descriptions)\n",
+ "\n",
+ "    import plotly.graph_objs as go\n",
+ "\n",
+ "    fig = go.Figure(go.Bar(\n",
+ "        x=contrib_df['contribution'],\n",
+ "        y=contrib_df['description'],\n",
+ "        orientation='h' # Horizontal bar chart\n",
+ "    ))\n",
+ "\n",
+ "    fig.update_layout(\n",
+ "        title='Feature contribution',\n",
+ "        xaxis_title='Contribution',\n",
+ "        yaxis_title='Features',\n",
+ "        height=500,\n",
+ "        margin=dict(l=200) # Increase left margin to accommodate longer feature names\n",
+ "    )\n",
+ "    fig.update_yaxes(autorange=\"reversed\")\n",
+ "\n",
+ "    probability = clf.predict_proba(X)[0]\n",
+ "    classes = {\n",
+ "        \"Positive\": probability[1],\n",
+ "        \"Negative\": probability[0]\n",
+ "    }\n",
+ "\n",
+ "    choices = list(zip(contrib_df[\"description\"], contrib_df[\"feature\"]))\n",
+ "    # No default selection when no feature contributes to the prediction.\n",
+ "    dropdown = gr.Dropdown(choices=choices,\n",
+ "                           value=choices[0][1] if choices else None,\n",
+ "                           interactive=True, label=\"Features\")\n",
+ "\n",
+ "    return classes, fig, dropdown"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "a sentimental mess that never rings true .\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "orientation": "h",
+ "type": "bar",
+ "x": [
+ 0.7149210223756529,
+ 0.5306234489651611,
+ 0.3787273657087757
+ ],
+ "y": [
+ "URLs and hyperlinks within the text",
+ " numerical values and statistical data representations",
+ "keywords and identifiers related to programming and networking concepts"
+ ]
+ }
+ ],
+ "layout": {
+ "height": 500,
+ "margin": {
+ "l": 200
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "title": {
+ "text": "Feature Importance"
+ },
+ "xaxis": {
+ "title": {
+ "text": "Importance"
+ }
+ },
+ "yaxis": {
+ "autorange": "reversed",
+ "title": {
+ "text": "Features"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "print(text)\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n",
+ "\n",
+ "target_act = gather_residual_activations(model, layer, inputs)\n",
+ "sae_act = sae.encode(target_act)\n",
+ "\n",
+ "\n",
+ "activated_tokens = sae_act[0:,:,feature]\n",
+ "# max_activation = activated_tokens.max().item()\n",
+ "# activated_tokens /= max_activation\n",
+ "\n",
+ "# activated_tokens = activated_tokens.cpu().detach().numpy()\n",
+ "\n",
+ "# output = []\n",
+ "\n",
+ "# for i, token_id in enumerate(inputs[0, :]):\n",
+ "# token = tokenizer.decode(token_id)\n",
+ "# output.append((token, activated_tokens[0, i]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_feature_iframe(feature):\n",
+ "    layer_name = f\"{layer}-gemmascope-res-{width}\"\n",
+ "    model_name_neuronpedia = model_name.split(\"/\")[1]\n",
+ "\n",
+ "    # Use the dashboard page rather than the JSON API endpoint, matching the embed URL template used earlier.\n",
+ "    url = f\"https://neuronpedia.org/{model_name_neuronpedia}/{layer_name}/{feature}?embed=true\"\n",
+ "    html_content = embed_content(url)\n",
+ "    return html_content\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "digital-video documentary about stand-up comedians is a great glimpse into a very different world . 1\n"
+ ]
+ }
+ ],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[62.0354, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
+ " device='cuda:0')"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n",
+ "\n",
+ "target_act = gather_residual_activations(model, layer, inputs)\n",
+ "sae_act = sae.encode(target_act)\n",
+ "\n",
+ "# Select this feature's activation at every token position: shape (batch, seq_len).\n",
+ "activated_tokens = sae_act[:, :, feature]\n",
+ "activated_tokens\n",
+ "# max_activation = activated_tokens.max().item()\n",
+ "# activated_tokens /= max_activation\n",
+ "\n",
+ "# activated_tokens = activated_tokens.cpu().detach().numpy()\n",
+ "\n",
+ "# output = []\n",
+ "\n",
+ "# for i, token_id in enumerate(inputs[0, :]):\n",
+ "# token = tokenizer.decode(token_id)\n",
+ "# output.append((token, activated_tokens[0, i]))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
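+ "def get_highlighted_text(text, feature):\n",
+ "    # Return (token, normalized activation) pairs for `feature` over `text`.\n",
+ "    inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n",
+ "\n",
+ "    target_act = gather_residual_activations(model, layer, inputs)\n",
+ "    sae_act = sae.encode(target_act)\n",
+ "\n",
+ "    # Skip position 0: it is the BOS token added by add_special_tokens=True,\n",
+ "    # and its activation can dwarf the activations on the actual text.\n",
+ "    activated_tokens = sae_act[:, 1:, feature]\n",
+ "    max_activation = activated_tokens.max().item()\n",
+ "    # Scale to [0, 1]. If the feature never fires, leave the zeros untouched\n",
+ "    # rather than dividing by zero and producing NaNs.\n",
+ "    if max_activation > 0:\n",
+ "        activated_tokens = activated_tokens / max_activation\n",
+ "\n",
+ "    activated_tokens = activated_tokens.cpu().detach().numpy()\n",
+ "\n",
+ "    output = []\n",
+ "\n",
+ "    # Pair each decoded token (BOS excluded) with its normalized activation.\n",
+ "    for i, token_id in enumerate(inputs[0, 1:]):\n",
+ "        token = tokenizer.decode(token_id)\n",
+ "        output.append((token, activated_tokens[0, i]))\n",
+ "\n",
+ "    return output"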
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'Positive': 0.9071025712081094, 'Negative': 0.09289742879189056}"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "0135f3b6c691405ea1d522cad0c66097": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e5dff475026a4f5398a8a05f75a3a264",
+ "placeholder": "",
+ "style": "IPY_MODEL_9a6f02d069a24c2299f3a5cb5889e62d",
+ "value": "generation_config.json: 100%"
+ }
+ },
+ "01b6b8374815487f8f8bdbfda028447d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "01fe915b05374809869338453e18bfac": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "084b6b6ec8f04d2c916bc53f535df8a3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0907314c0f214768b3c605168c7d37aa": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6333722f782b421f892d2b02729c6304",
+ "placeholder": "",
+ "style": "IPY_MODEL_f57d69bf095d45998baf5d7bf12accb8",
+ "value": "Loading checkpoint shards: 100%"
+ }
+ },
+ "0e904b86aeaa4fdbafa9eb6b1d5b4808": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0ebfd571b5594d81bdd724242371a0d7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "12d40111a6a34ebc8460956895f1ac20": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9bc15b80984946d4b75e13040ece4901",
+ "placeholder": "",
+ "style": "IPY_MODEL_cdfd34c475c44e6babf1ca6ef9ce70ed",
+ "value": " 302M/302M [00:02<00:00, 138MB/s]"
+ }
+ },
+ "1800de9c72c64f62bd686b69c5841e7b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "199409203a804e81865a67b881c820c4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "1a5c2570dea344479476c636954cf2f9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1bfe8b9540484dd2af1c03ffc0f84314": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_80c8f84f8dfd4df08bcc6ea480a20ab6",
+ "max": 4983443424,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_01fe915b05374809869338453e18bfac",
+ "value": 4983443424
+ }
+ },
+ "1dd2cd6a7721463481e2963873699b6f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2129c74c13df48c894eb3b7b6e4f3f8c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_3df1fe78dd564517ad38bc6e463cb7be",
+ "max": 302131416,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_deb7b5bcbdf44f0dad7a2a648b48a021",
+ "value": 302131416
+ }
+ },
+ "21ec235ad645477ab00eed8eed917a74": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "259354987c1b4c30b9e26d7c09df0526": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a9f8444633c746c9b5432e1aca71615f",
+ "placeholder": "",
+ "style": "IPY_MODEL_21ec235ad645477ab00eed8eed917a74",
+ "value": "model-00002-of-00003.safetensors: 100%"
+ }
+ },
+ "2652458b30794693853e1707262f8cd9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_2b8513db59bb45999581b9bfddde0c87",
+ "placeholder": "",
+ "style": "IPY_MODEL_e25276a157174f35af4776e77942b1fb",
+ "value": " 168/168 [00:00<00:00, 12.9kB/s]"
+ }
+ },
+ "273fd4bfca3444d3a8b19d9ee3e96db1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2b8513db59bb45999581b9bfddde0c87": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2c69de0b32fe41c5b8e9099246a689dd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "2c7823a443ae4a1bad237ec683078e7f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "CheckboxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "CheckboxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "CheckboxView",
+ "description": "Add token as git credential?",
+ "description_tooltip": null,
+ "disabled": false,
+ "indent": true,
+ "layout": "IPY_MODEL_8d1aa47245274dc79a25197b5a7de14e",
+ "style": "IPY_MODEL_f3b0f2e26a82474aa9075b3b94630a86",
+ "value": false
+ }
+ },
+ "2cb9454ed503456193b7a446dbc47bcc": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "30e9523940604644b8ba5984a4f2f1d2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "32497c90822b43609fd8847bce049262": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_b18fb9f8ccca42f9a05a5bfc2acf9f07",
+ "placeholder": "",
+ "style": "IPY_MODEL_86a6a43b172c4e71ba6399d2e6975323",
+ "value": " 3/3 [04:06<00:00, 75.04s/it]"
+ }
+ },
+ "3c5993356a6f47f6b11a529e201f2353": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ButtonStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "button_color": null,
+ "font_weight": ""
+ }
+ },
+ "3d390ffc96fb42b59161fe758956ad6d": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3d5bb02b899e47a5952331aa9e74b53f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ButtonModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ButtonView",
+ "button_style": "",
+ "description": "Login",
+ "disabled": false,
+ "icon": "",
+ "layout": "IPY_MODEL_d7a8b324cab04700bf674475be382f69",
+ "style": "IPY_MODEL_3c5993356a6f47f6b11a529e201f2353",
+ "tooltip": ""
+ }
+ },
+ "3df1fe78dd564517ad38bc6e463cb7be": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4188cc50baeb43f09e6d44d72ee779fa": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a0d9e7723bbd4d7aaa32b6714b2feaa9",
+ "placeholder": "",
+ "style": "IPY_MODEL_01b6b8374815487f8f8bdbfda028447d",
+ "value": "model-00003-of-00003.safetensors: 100%"
+ }
+ },
+ "4241d7fa63914a228fa22169620f9711": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d0d64d5c6728438a91b110a522d67dfb",
+ "placeholder": "",
+ "style": "IPY_MODEL_5c351bf198324891858e280a70b7bebd",
+ "value": " 4.98G/4.98G [03:40<00:00, 23.8MB/s]"
+ }
+ },
+ "4b9c7dedc7414d859ea089af36fd11e5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "504406c6dffc4b0489ac6eba9d1e024d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "LabelModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "LabelModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "LabelView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0e904b86aeaa4fdbafa9eb6b1d5b4808",
+ "placeholder": "",
+ "style": "IPY_MODEL_811718d577e04953aec7cb75259536aa",
+ "value": "Token is valid (permission: read)."
+ }
+ },
+ "5319c56843cb4c9aad4087fde3739427": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5418b7eeccb044259def851cddd459a8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "LabelModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "LabelModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "LabelView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9cbe551f05784a60aecef170719d029c",
+ "placeholder": "",
+ "style": "IPY_MODEL_0ebfd571b5594d81bdd724242371a0d7",
+ "value": "Your token has been saved to /root/.cache/huggingface/token"
+ }
+ },
+ "5529e02b668d4af08dae7e74366da305": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "PasswordModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "PasswordModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "PasswordView",
+ "continuous_update": true,
+ "description": "Token:",
+ "description_tooltip": null,
+ "disabled": false,
+ "layout": "IPY_MODEL_084b6b6ec8f04d2c916bc53f535df8a3",
+ "placeholder": "",
+ "style": "IPY_MODEL_ea19037fd122497492342715ded64099",
+ "value": ""
+ }
+ },
+ "57d74d1926fb41cb87373b7509080ad2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "58f8fb3c7ce848babcdc153df19ced55": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "5c351bf198324891858e280a70b7bebd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "6333722f782b421f892d2b02729c6304": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "64239117b4dc4c16b17788ccbe34266f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6c7e7a09370d4b81ba782030575eeaa2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6fc186b9b8fd4a7a9a950173198fea66": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "756004657762404882aa19d3185c10e3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_4188cc50baeb43f09e6d44d72ee779fa",
+ "IPY_MODEL_8d8ad710b1ca420cb338042cb9ce3a9f",
+ "IPY_MODEL_d9051750811a425c9084c57b80ba24b0"
+ ],
+ "layout": "IPY_MODEL_aff6b37ce7df414fad0817514477879b"
+ }
+ },
+ "7605920df5734f689bc9c5816d8133c8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "7a7d7b8540bb401f92c6f43e2f4afcae": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "VBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "VBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "VBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_504406c6dffc4b0489ac6eba9d1e024d",
+ "IPY_MODEL_5418b7eeccb044259def851cddd459a8",
+ "IPY_MODEL_9ba3205fae3a48c58d89dbd171ddf4ce"
+ ],
+ "layout": "IPY_MODEL_d0d140a29b5242cb9325458e198f90fa"
+ }
+ },
+ "80c8f84f8dfd4df08bcc6ea480a20ab6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "811718d577e04953aec7cb75259536aa": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "836808558f274f009d63781f99e4be25": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_30e9523940604644b8ba5984a4f2f1d2",
+ "max": 3,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_cd516a5df8454429a2628ff0d69ae023",
+ "value": 3
+ }
+ },
+ "85b91daf4f4842058daa3916d2260e71": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "86a6a43b172c4e71ba6399d2e6975323": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "8d1aa47245274dc79a25197b5a7de14e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8d296bdf07a44fdb8528df7f1989708a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0907314c0f214768b3c605168c7d37aa",
+ "IPY_MODEL_99f9bb38dba4417595cfdd040882c3de",
+ "IPY_MODEL_a0f8f64c4d0c47658976fee492692bec"
+ ],
+ "layout": "IPY_MODEL_5319c56843cb4c9aad4087fde3739427"
+ }
+ },
+ "8d8ad710b1ca420cb338042cb9ce3a9f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_57d74d1926fb41cb87373b7509080ad2",
+ "max": 481381384,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_de033c141a5b492ab10cea354e7dda12",
+ "value": 481381384
+ }
+ },
+ "9405fa0476af4bcda2b3e6a9953c5d0e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9675c8033d02430dbca7e386a23ca109": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "97ad61a843c444a496ad54d647d6bc2d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_cccba553cff64b6b9951fb976ceec179",
+ "max": 168,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_2cb9454ed503456193b7a446dbc47bcc",
+ "value": 168
+ }
+ },
+ "99f9bb38dba4417595cfdd040882c3de": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_1dd2cd6a7721463481e2963873699b6f",
+ "max": 3,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_58f8fb3c7ce848babcdc153df19ced55",
+ "value": 3
+ }
+ },
+ "9a6f02d069a24c2299f3a5cb5889e62d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "9ba3205fae3a48c58d89dbd171ddf4ce": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "LabelModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "LabelModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "LabelView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6fc186b9b8fd4a7a9a950173198fea66",
+ "placeholder": "",
+ "style": "IPY_MODEL_cb9a42f8e7dc44f4bb4cc73809bdf409",
+ "value": "Login successful"
+ }
+ },
+ "9bc15b80984946d4b75e13040ece4901": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9cbe551f05784a60aecef170719d029c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9ce9fa9d715a4dc0a0b5fa2778ac04e3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_1a5c2570dea344479476c636954cf2f9",
+ "placeholder": "",
+ "style": "IPY_MODEL_199409203a804e81865a67b881c820c4",
+ "value": "params.npz: 100%"
+ }
+ },
+ "a0d9e7723bbd4d7aaa32b6714b2feaa9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a0f8f64c4d0c47658976fee492692bec": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9405fa0476af4bcda2b3e6a9953c5d0e",
+ "placeholder": "",
+ "style": "IPY_MODEL_1800de9c72c64f62bd686b69c5841e7b",
+ "value": " 3/3 [00:43<00:00, 12.24s/it]"
+ }
+ },
+ "a3ce764762d84004b8bd77c062b86ccf": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_259354987c1b4c30b9e26d7c09df0526",
+ "IPY_MODEL_1bfe8b9540484dd2af1c03ffc0f84314",
+ "IPY_MODEL_4241d7fa63914a228fa22169620f9711"
+ ],
+ "layout": "IPY_MODEL_c3c35ce03769471c8be12e9a327bf577"
+ }
+ },
+ "a7a0870cdd1b47ccb48c418a4b02066b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a9f8444633c746c9b5432e1aca71615f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "aa35c0554bdd4065a931ae7e9431c5ce": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "aa7a7c4e96fe4a0fa2e72a2579c37799": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_9ce9fa9d715a4dc0a0b5fa2778ac04e3",
+ "IPY_MODEL_2129c74c13df48c894eb3b7b6e4f3f8c",
+ "IPY_MODEL_12d40111a6a34ebc8460956895f1ac20"
+ ],
+ "layout": "IPY_MODEL_273fd4bfca3444d3a8b19d9ee3e96db1"
+ }
+ },
+ "adf72ecf85804abdbf29ba3b1041c8bf": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "LabelModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "LabelModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "LabelView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_aa35c0554bdd4065a931ae7e9431c5ce",
+ "placeholder": "",
+ "style": "IPY_MODEL_4b9c7dedc7414d859ea089af36fd11e5",
+ "value": "Connecting..."
+ }
+ },
+ "aff6b37ce7df414fad0817514477879b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b1506dbf12a44b3fa30e392afb0dab56": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_64239117b4dc4c16b17788ccbe34266f",
+ "placeholder": "",
+ "style": "IPY_MODEL_2c69de0b32fe41c5b8e9099246a689dd",
+ "value": "\nPro Tip: If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' access, that you can then easily reuse for all\nnotebooks. "
+ }
+ },
+ "b18fb9f8ccca42f9a05a5bfc2acf9f07": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b55d2a6c60854cd5a0ea7e60a3fba879": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_3d390ffc96fb42b59161fe758956ad6d",
+ "placeholder": "",
+ "style": "IPY_MODEL_85b91daf4f4842058daa3916d2260e71",
+ "value": "Downloading shards: 100%"
+ }
+ },
+ "c3c35ce03769471c8be12e9a327bf577": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c65ebd4e47d24587b84540fca07a7978": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cb9a42f8e7dc44f4bb4cc73809bdf409": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "cccba553cff64b6b9951fb976ceec179": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cd516a5df8454429a2628ff0d69ae023": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "cdfd34c475c44e6babf1ca6ef9ce70ed": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "d0340edd6ce54da4bb5588c4b58ccbd4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_b55d2a6c60854cd5a0ea7e60a3fba879",
+ "IPY_MODEL_836808558f274f009d63781f99e4be25",
+ "IPY_MODEL_32497c90822b43609fd8847bce049262"
+ ],
+ "layout": "IPY_MODEL_9675c8033d02430dbca7e386a23ca109"
+ }
+ },
+ "d0d140a29b5242cb9325458e198f90fa": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": "center",
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": "flex",
+ "flex": null,
+ "flex_flow": "column",
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": "50%"
+ }
+ },
+ "d0d64d5c6728438a91b110a522d67dfb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d30a23ea65fb4f808236edce817bae51": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a7a0870cdd1b47ccb48c418a4b02066b",
+ "placeholder": "",
+ "style": "IPY_MODEL_7605920df5734f689bc9c5816d8133c8",
+ "value": " Copy a token from your Hugging Face\ntokens page and paste it below. Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file. "
+ }
+ },
+ "d7a8b324cab04700bf674475be382f69": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d9051750811a425c9084c57b80ba24b0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6c7e7a09370d4b81ba782030575eeaa2",
+ "placeholder": "",
+ "style": "IPY_MODEL_df9f4163904644ce9f6a7d76121095e2",
+ "value": " 481M/481M [00:24<00:00, 23.1MB/s]"
+ }
+ },
+ "de033c141a5b492ab10cea354e7dda12": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "deb7b5bcbdf44f0dad7a2a648b48a021": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "df9f4163904644ce9f6a7d76121095e2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "e25276a157174f35af4776e77942b1fb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "e5dff475026a4f5398a8a05f75a3a264": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "ea19037fd122497492342715ded64099": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ee4f21abcd3d489d84931e88d3c51b19": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0135f3b6c691405ea1d522cad0c66097",
+ "IPY_MODEL_97ad61a843c444a496ad54d647d6bc2d",
+ "IPY_MODEL_2652458b30794693853e1707262f8cd9"
+ ],
+ "layout": "IPY_MODEL_c65ebd4e47d24587b84540fca07a7978"
+ }
+ },
+ "f3b0f2e26a82474aa9075b3b94630a86": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "f57d69bf095d45998baf5d7bf12accb8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}