Prasanna Sridhar commited on
Commit
96f9e24
·
1 Parent(s): aedd89b

Add sample notebook

Browse files
Files changed (2) hide show
  1. app.py +7 -5
  2. notebooks/demo.ipynb +722 -0
app.py CHANGED
@@ -214,7 +214,7 @@ def generate_output_label(text, num_exemplars):
214
 
215
  return out_label
216
 
217
- def preprocess(image, input_prompts = None):
218
  if input_prompts == None:
219
  prompts = { "image": image, "points": []}
220
  else:
@@ -240,9 +240,9 @@ def get_boxes_from_prediction(model_output, text, keywords = ""):
240
  logits = logits[box_mask, :].cpu().numpy()
241
  return boxes, logits
242
 
243
- def predict(model, image, text, prompts, device):
244
  keywords = "" # do not handle this for now
245
- input_image, input_image_exemplar, exemplar = preprocess(image, prompts)
246
 
247
  input_images = input_image.unsqueeze(0).to(device)
248
  input_image_exemplars = input_image_exemplar.unsqueeze(0).to(device)
@@ -285,12 +285,14 @@ if __name__ == '__main__':
285
  model, transform = build_model_and_transforms(args)
286
  model = model.to(device)
287
 
 
 
288
  @spaces.GPU(duration=120)
289
  def count(image, text, prompts, state, device):
290
  if prompts is None:
291
  prompts = {"image": image, "points": []}
292
 
293
- boxes, _ = predict(model, image, text, prompts, device)
294
  count = len(boxes)
295
  output_img = generate_heatmap(image, boxes)
296
 
@@ -321,7 +323,7 @@ if __name__ == '__main__':
321
  def count_main(image, text, prompts, device):
322
  if prompts is None:
323
  prompts = {"image": image, "points": []}
324
- boxes, _ = predict(model, image, text, prompts, device)
325
  count = len(boxes)
326
  output_img = generate_heatmap(image, boxes)
327
  num_exemplars = len(get_box_inputs(prompts["points"]))
 
214
 
215
  return out_label
216
 
217
+ def preprocess(transform, image, input_prompts = None):
218
  if input_prompts == None:
219
  prompts = { "image": image, "points": []}
220
  else:
 
240
  logits = logits[box_mask, :].cpu().numpy()
241
  return boxes, logits
242
 
243
+ def predict(model, transform, image, text, prompts, device):
244
  keywords = "" # do not handle this for now
245
+ input_image, input_image_exemplar, exemplar = preprocess(transform, image, prompts)
246
 
247
  input_images = input_image.unsqueeze(0).to(device)
248
  input_image_exemplars = input_image_exemplar.unsqueeze(0).to(device)
 
285
  model, transform = build_model_and_transforms(args)
286
  model = model.to(device)
287
 
288
+ _predict = lambda image, text, prompts: predict(model, transform, image, text, prompts, device)
289
+
290
  @spaces.GPU(duration=120)
291
  def count(image, text, prompts, state, device):
292
  if prompts is None:
293
  prompts = {"image": image, "points": []}
294
 
295
+ boxes, _ = _predict(image, text, prompts)
296
  count = len(boxes)
297
  output_img = generate_heatmap(image, boxes)
298
 
 
323
  def count_main(image, text, prompts, device):
324
  if prompts is None:
325
  prompts = {"image": image, "points": []}
326
+ boxes, _ = _predict(image, text, prompts)
327
  count = len(boxes)
328
  output_img = generate_heatmap(image, boxes)
329
  num_exemplars = len(get_box_inputs(prompts["points"]))
notebooks/demo.ipynb ADDED
@@ -0,0 +1,722 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "yxig5CdZuHb9"
7
+ },
8
+ "source": [
9
+ "# CountGD - Multimodela open-world object counting\n",
10
+ "\n"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "9wyM6J2HuHb-"
17
+ },
18
+ "source": [
19
+ "## Setup\n",
20
+ "\n",
21
+ "The following cells will setup the runtime environment with the following\n",
22
+ "\n",
23
+ "- Mount Google Drive\n",
24
+ "- Install dependencies for running the model\n",
25
+ "- Load the model into memory"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {
31
+ "id": "jn061Tl8uHb-"
32
+ },
33
+ "source": [
34
+ "### Mount Google Drive (if running on colab)\n",
35
+ "\n",
36
+ "The following bit of code will mount your Google Drive folder at `/content/drive`, allowing you to process files directly from it as well as store the results alongside it.\n",
37
+ "\n",
38
+ "Once you execute the next cell, you will be requested to share access with the notebook. Please follow the instructions on screen to do so.\n",
39
+ "If you are not running this on colab, you will still be able to use the files available on your environment."
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 1,
45
+ "metadata": {
46
+ "colab": {
47
+ "base_uri": "https://localhost:8080/"
48
+ },
49
+ "collapsed": true,
50
+ "id": "DkSUXqMPuHb-",
51
+ "outputId": "6b82521e-3afd-4545-b13f-8cfea0975d95"
52
+ },
53
+ "outputs": [
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n",
59
+ "env: RUNNING_IN_COLAB=True\n"
60
+ ]
61
+ }
62
+ ],
63
+ "source": [
64
+ "# Check if running colab\n",
65
+ "import logging\n",
66
+ "\n",
67
+ "logging.basicConfig(\n",
68
+ " level=logging.INFO,\n",
69
+ " format='%(asctime)s %(levelname)-8s %(name)s %(message)s'\n",
70
+ ")\n",
71
+ "try:\n",
72
+ " import google.colab\n",
73
+ " RUNNING_IN_COLAB = True\n",
74
+ "except:\n",
75
+ " RUNNING_IN_COLAB = False\n",
76
+ "\n",
77
+ "if RUNNING_IN_COLAB:\n",
78
+ " from google.colab import drive\n",
79
+ " drive.mount('/content/drive')\n",
80
+ "\n",
81
+ "from IPython.core.magic import register_cell_magic\n",
82
+ "from IPython import get_ipython\n",
83
+ "@register_cell_magic\n",
84
+ "def skip_if(line, cell):\n",
85
+ " if eval(line):\n",
86
+ " return\n",
87
+ " get_ipython().run_cell(cell)\n",
88
+ "\n",
89
+ "\n",
90
+ "%env RUNNING_IN_COLAB {RUNNING_IN_COLAB}\n"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "metadata": {
96
+ "id": "kas5YtyluHb_"
97
+ },
98
+ "source": [
99
+ "### Install Dependencies\n",
100
+ "\n",
101
+ "The environment will be setup with the code, models and required dependencies."
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 3,
107
+ "metadata": {
108
+ "colab": {
109
+ "base_uri": "https://localhost:8080/"
110
+ },
111
+ "id": "982Yiv5tuHb_",
112
+ "outputId": "2f570d1a-c6cc-49c3-c336-1d784d33a169"
113
+ },
114
+ "outputs": [
115
+ {
116
+ "name": "stdout",
117
+ "output_type": "stream",
118
+ "text": [
119
+ "Downloading the repository...\n",
120
+ "Branch 'pr/5' set up to track remote branch 'pr/5' from 'origin'.\n",
121
+ "Requirement already satisfied: pip in /usr/local/lib/python3.11/dist-packages (24.3.1)\n",
122
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (75.8.0)\n",
123
+ "Requirement already satisfied: wheel in /usr/local/lib/python3.11/dist-packages (0.45.1)\n",
124
+ "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu121\n",
125
+ "Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 1)) (1.13.1)\n",
126
+ "Requirement already satisfied: termcolor in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 2)) (2.5.0)\n",
127
+ "Requirement already satisfied: addict in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 3)) (2.4.0)\n",
128
+ "Requirement already satisfied: yapf==0.40.1 in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 4)) (0.40.1)\n",
129
+ "Requirement already satisfied: timm in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 5)) (1.0.13)\n",
130
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 6)) (1.26.4)\n",
131
+ "Requirement already satisfied: opencv-python in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 7)) (4.10.0.84)\n",
132
+ "Requirement already satisfied: pycocotools in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 8)) (2.0.8)\n",
133
+ "Requirement already satisfied: colorlog in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 9)) (6.9.0)\n",
134
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 10)) (75.8.0)\n",
135
+ "Requirement already satisfied: ushlex in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 11)) (0.99.1)\n",
136
+ "Requirement already satisfied: gradio<5,>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 12)) (4.44.1)\n",
137
+ "Requirement already satisfied: gradio-image-prompter in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 13)) (0.1.0)\n",
138
+ "Requirement already satisfied: spaces in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 14)) (0.32.0)\n",
139
+ "Collecting filetype (from -r requirements.txt (line 15))\n",
140
+ " Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)\n",
141
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 16)) (4.67.1)\n",
142
+ "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 18)) (2.5.1+cu121)\n",
143
+ "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 19)) (0.20.1+cu121)\n",
144
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (from -r requirements.txt (line 20)) (4.47.1)\n",
145
+ "Requirement already satisfied: importlib-metadata>=6.6.0 in /usr/local/lib/python3.11/dist-packages (from yapf==0.40.1->-r requirements.txt (line 4)) (8.5.0)\n",
146
+ "Requirement already satisfied: platformdirs>=3.5.1 in /usr/local/lib/python3.11/dist-packages (from yapf==0.40.1->-r requirements.txt (line 4)) (4.3.6)\n",
147
+ "Requirement already satisfied: tomli>=2.0.1 in /usr/local/lib/python3.11/dist-packages (from yapf==0.40.1->-r requirements.txt (line 4)) (2.2.1)\n",
148
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.11/dist-packages (from timm->-r requirements.txt (line 5)) (6.0.2)\n",
149
+ "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.11/dist-packages (from timm->-r requirements.txt (line 5)) (0.27.1)\n",
150
+ "Requirement already satisfied: safetensors in /usr/local/lib/python3.11/dist-packages (from timm->-r requirements.txt (line 5)) (0.5.2)\n",
151
+ "Requirement already satisfied: matplotlib>=2.1.0 in /usr/local/lib/python3.11/dist-packages (from pycocotools->-r requirements.txt (line 8)) (3.10.0)\n",
152
+ "Requirement already satisfied: aiofiles<24.0,>=22.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (23.2.1)\n",
153
+ "Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.7.1)\n",
154
+ "Requirement already satisfied: fastapi<1.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.115.6)\n",
155
+ "Requirement already satisfied: ffmpy in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.5.0)\n",
156
+ "Requirement already satisfied: gradio-client==1.3.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (1.3.0)\n",
157
+ "Requirement already satisfied: httpx>=0.24.1 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.28.1)\n",
158
+ "Requirement already satisfied: importlib-resources<7.0,>=1.3 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (6.5.2)\n",
159
+ "Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.1.5)\n",
160
+ "Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.1.5)\n",
161
+ "Requirement already satisfied: orjson~=3.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.10.14)\n",
162
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (24.2)\n",
163
+ "Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.2.2)\n",
164
+ "Requirement already satisfied: pillow<11.0,>=8.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (10.4.0)\n",
165
+ "Requirement already satisfied: pydantic>=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.10.5)\n",
166
+ "Requirement already satisfied: pydub in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.25.1)\n",
167
+ "Requirement already satisfied: python-multipart>=0.0.9 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.0.20)\n",
168
+ "Requirement already satisfied: ruff>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.9.2)\n",
169
+ "Requirement already satisfied: semantic-version~=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.10.0)\n",
170
+ "Requirement already satisfied: tomlkit==0.12.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.12.0)\n",
171
+ "Requirement already satisfied: typer<1.0,>=0.12 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.15.1)\n",
172
+ "Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (4.12.2)\n",
173
+ "Requirement already satisfied: urllib3~=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.3.0)\n",
174
+ "Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.11/dist-packages (from gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.34.0)\n",
175
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from gradio-client==1.3.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2024.10.0)\n",
176
+ "Requirement already satisfied: websockets<13.0,>=10.0 in /usr/local/lib/python3.11/dist-packages (from gradio-client==1.3.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (12.0)\n",
177
+ "Requirement already satisfied: psutil<6,>=2 in /usr/local/lib/python3.11/dist-packages (from spaces->-r requirements.txt (line 14)) (5.9.5)\n",
178
+ "Requirement already satisfied: requests<3.0,>=2.19 in /usr/local/lib/python3.11/dist-packages (from spaces->-r requirements.txt (line 14)) (2.32.3)\n",
179
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (3.16.1)\n",
180
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (3.4.2)\n",
181
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.105)\n",
182
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.105)\n",
183
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.105)\n",
184
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (9.1.0.70)\n",
185
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.3.1)\n",
186
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (11.0.2.54)\n",
187
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (10.3.2.106)\n",
188
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (11.4.5.107)\n",
189
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.0.106)\n",
190
+ "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (2.21.5)\n",
191
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (12.1.105)\n",
192
+ "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (3.1.0)\n",
193
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch->-r requirements.txt (line 18)) (1.13.1)\n",
194
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.11/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->-r requirements.txt (line 18)) (12.6.85)\n",
195
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch->-r requirements.txt (line 18)) (1.3.0)\n",
196
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers->-r requirements.txt (line 20)) (2024.11.6)\n",
197
+ "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers->-r requirements.txt (line 20)) (0.21.0)\n",
198
+ "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.11/dist-packages (from anyio<5.0,>=3.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.10)\n",
199
+ "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio<5.0,>=3.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (1.3.1)\n",
200
+ "Requirement already satisfied: starlette<0.42.0,>=0.40.0 in /usr/local/lib/python3.11/dist-packages (from fastapi<1.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.41.3)\n",
201
+ "Requirement already satisfied: certifi in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2024.12.14)\n",
202
+ "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (1.0.7)\n",
203
+ "Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/dist-packages (from httpcore==1.*->httpx>=0.24.1->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.14.0)\n",
204
+ "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.11/dist-packages (from importlib-metadata>=6.6.0->yapf==0.40.1->-r requirements.txt (line 4)) (3.21.0)\n",
205
+ "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (1.3.1)\n",
206
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (0.12.1)\n",
207
+ "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (4.55.3)\n",
208
+ "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (1.4.8)\n",
209
+ "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (3.2.1)\n",
210
+ "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (2.8.2)\n",
211
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0,>=1.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2024.2)\n",
212
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0,>=1.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2024.2)\n",
213
+ "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from pydantic>=2.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.7.0)\n",
214
+ "Requirement already satisfied: pydantic-core==2.27.2 in /usr/local/lib/python3.11/dist-packages (from pydantic>=2.0->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.27.2)\n",
215
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3.0,>=2.19->spaces->-r requirements.txt (line 14)) (3.4.1)\n",
216
+ "Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.11/dist-packages (from typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (8.1.8)\n",
217
+ "Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (1.5.4)\n",
218
+ "Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.11/dist-packages (from typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (13.9.4)\n",
219
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.7->matplotlib>=2.1.0->pycocotools->-r requirements.txt (line 8)) (1.17.0)\n",
220
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (3.0.0)\n",
221
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (2.18.0)\n",
222
+ "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio<5,>=4.0.0->-r requirements.txt (line 12)) (0.1.2)\n",
223
+ "Downloading filetype-1.2.0-py2.py3-none-any.whl (19 kB)\n",
224
+ "Installing collected packages: filetype\n",
225
+ "Successfully installed filetype-1.2.0\n",
226
+ "inside get_extensions\n",
227
+ "/usr/local/cuda/\n",
228
+ "running build\n",
229
+ "running build_py\n",
230
+ "copying modules/ms_deform_attn.py -> build/lib.linux-x86_64-cpython-311/modules\n",
231
+ "copying modules/__init__.py -> build/lib.linux-x86_64-cpython-311/modules\n",
232
+ "copying functions/__init__.py -> build/lib.linux-x86_64-cpython-311/functions\n",
233
+ "copying functions/ms_deform_attn_func.py -> build/lib.linux-x86_64-cpython-311/functions\n",
234
+ "running build_ext\n",
235
+ "Processing /content/countgd/models/GroundingDINO/ops\n",
236
+ " Preparing metadata (setup.py): started\n",
237
+ " Preparing metadata (setup.py): finished with status 'done'\n",
238
+ "Building wheels for collected packages: MultiScaleDeformableAttention\n",
239
+ " Building wheel for MultiScaleDeformableAttention (setup.py): started\n",
240
+ " Building wheel for MultiScaleDeformableAttention (setup.py): finished with status 'done'\n",
241
+ " Created wheel for MultiScaleDeformableAttention: filename=MultiScaleDeformableAttention-1.0-cp311-cp311-linux_x86_64.whl size=2968787 sha256=98492f743278ba68af0ab7dad6642b06a78f2e0633591b993addc80f682f3157\n",
242
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-fen8yoft/wheels/64/89/f6/145aead02469873626bfd93ef130da365da640c570c39dd3d5\n",
243
+ "Successfully built MultiScaleDeformableAttention\n",
244
+ "Installing collected packages: MultiScaleDeformableAttention\n",
245
+ " Attempting uninstall: MultiScaleDeformableAttention\n",
246
+ " Found existing installation: MultiScaleDeformableAttention 1.0\n",
247
+ " Uninstalling MultiScaleDeformableAttention-1.0:\n",
248
+ " Successfully uninstalled MultiScaleDeformableAttention-1.0\n",
249
+ "Successfully installed MultiScaleDeformableAttention-1.0\n",
250
+ "* True check_forward_equal_with_pytorch_double: max_abs_err 8.67e-19 max_rel_err 2.35e-16\n",
251
+ "* True check_forward_equal_with_pytorch_float: max_abs_err 4.66e-10 max_rel_err 1.13e-07\n",
252
+ "* True check_gradient_numerical(D=30)\n",
253
+ "* True check_gradient_numerical(D=32)\n",
254
+ "* True check_gradient_numerical(D=64)\n",
255
+ "* True check_gradient_numerical(D=71)\n"
256
+ ]
257
+ },
258
+ {
259
+ "name": "stderr",
260
+ "output_type": "stream",
261
+ "text": [
262
+ "+ '[' True == True ']'\n",
263
+ "+ echo 'Downloading the repository...'\n",
264
+ "+ '[' '!' -d /content/countgd ']'\n",
265
+ "+ cd /content/countgd\n",
266
+ "+ git fetch origin refs/pr/5:refs/remotes/origin/pr/5\n",
267
+ "From https://huggingface.co/spaces/nikigoli/countgd\n",
268
+ " * [new ref] refs/pr/5 -> origin/pr/5\n",
269
+ "+ git checkout pr/5\n",
270
+ "Switched to a new branch 'pr/5'\n",
271
+ "+ pip install --upgrade pip setuptools wheel\n",
272
+ "+ pip install -r requirements.txt\n",
273
+ "+ export CUDA_HOME=/usr/local/cuda/\n",
274
+ "+ CUDA_HOME=/usr/local/cuda/\n",
275
+ "+ cd models/GroundingDINO/ops\n",
276
+ "+ python3 setup.py build\n",
277
+ "/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:497: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.\n",
278
+ " warnings.warn(msg.format('we could not find ninja.'))\n",
279
+ "/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:416: UserWarning: The detected CUDA version (12.2) has a minor version mismatch with the version that was used to compile PyTorch (12.1). Most likely this shouldn't be a problem.\n",
280
+ " warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))\n",
281
+ "/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:426: UserWarning: There are no x86_64-linux-gnu-g++ version bounds defined for CUDA version 12.2\n",
282
+ " warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')\n",
283
+ "+ pip install .\n",
284
+ "+ python3 test.py\n"
285
+ ]
286
+ }
287
+ ],
288
+ "source": [
289
+ "%%bash\n",
290
+ "\n",
291
+ "set -euxo pipefail\n",
292
+ "\n",
293
+ "if [ \"${RUNNING_IN_COLAB}\" == \"True\" ]; then\n",
294
+ " echo \"Downloading the repository...\"\n",
295
+ " if [ ! -d /content/countgd ]; then\n",
296
+ " git clone \"https://huggingface.co/spaces/nikigoli/countgd\" /content/countgd\n",
297
+ " fi\n",
298
+ " cd /content/countgd\n",
299
+ " git fetch origin refs/pr/5:refs/remotes/origin/pr/5\n",
300
+ " git checkout pr/5\n",
301
+ "else\n",
302
+ " # TODO check if cwd is the correct git repo\n",
303
+ " # If users use vscode, then we set the default start directory to root of the repo\n",
304
+ " echo \"Running in $(pwd)\"\n",
305
+ "fi\n",
306
+ "\n",
307
+ "# TODO check for gcc-11 or above\n",
308
+ "\n",
309
+ "# Install pip packages\n",
310
+ "pip install --upgrade pip setuptools wheel\n",
311
+ "pip install -r requirements.txt\n",
312
+ "\n",
313
+ "# Compile modules\n",
314
+ "export CUDA_HOME=/usr/local/cuda/\n",
315
+ "cd models/GroundingDINO/ops\n",
316
+ "python3 setup.py build\n",
317
+ "pip install .\n",
318
+ "python3 test.py"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": 4,
324
+ "metadata": {
325
+ "colab": {
326
+ "base_uri": "https://localhost:8080/"
327
+ },
328
+ "id": "58iD_HGnvcRJ",
329
+ "outputId": "fe356a68-dced-4f6f-93cc-d83da2f84e28"
330
+ },
331
+ "outputs": [
332
+ {
333
+ "name": "stdout",
334
+ "output_type": "stream",
335
+ "text": [
336
+ "/content/countgd\n"
337
+ ]
338
+ }
339
+ ],
340
+ "source": [
341
+ "%cd {\"/content/countgd\" if RUNNING_IN_COLAB else '.'}"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "markdown",
346
+ "metadata": {
347
+ "id": "gH7A8zthuHb_"
348
+ },
349
+ "source": [
350
+ "## Inference"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "markdown",
355
+ "metadata": {
356
+ "id": "IspbBV0XuHb_"
357
+ },
358
+ "source": [
359
+ "### Loading the model"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": 11,
365
+ "metadata": {
366
+ "colab": {
367
+ "base_uri": "https://localhost:8080/"
368
+ },
369
+ "id": "5nBT_HCUuHb_",
370
+ "outputId": "95ceb6c6-bee8-4921-8bff-d28937045f78"
371
+ },
372
+ "outputs": [
373
+ {
374
+ "name": "stderr",
375
+ "output_type": "stream",
376
+ "text": [
377
+ "Some weights of BertModel were not initialized from the model checkpoint at checkpoints/bert-base-uncased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']\n",
378
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
379
+ ]
380
+ },
381
+ {
382
+ "name": "stdout",
383
+ "output_type": "stream",
384
+ "text": [
385
+ "final text_encoder_type: checkpoints/bert-base-uncased\n",
386
+ "load tokenizer done.\n",
387
+ "final text_encoder_type: checkpoints/bert-base-uncased\n",
388
+ "load tokenizer done.\n"
389
+ ]
390
+ }
391
+ ],
392
+ "source": [
393
+ "import app\n",
394
+ "import importlib\n",
395
+ "importlib.reload(app)\n",
396
+ "from app import (\n",
397
+ " build_model_and_transforms,\n",
398
+ " get_device,\n",
399
+ " get_args_parser,\n",
400
+ " generate_heatmap,\n",
401
+ " predict,\n",
402
+ ")\n",
403
+ "args = get_args_parser().parse_args([])\n",
404
+ "device = get_device()\n",
405
+ "model, transform = build_model_and_transforms(args)\n",
406
+ "model = model.to(device)\n",
407
+ "\n",
408
+ "run = lambda image, text: predict(model, transform, image, text, None, device)\n",
409
+ "get_output = lambda image, boxes: (len(boxes), generate_heatmap(image, boxes))\n"
410
+ ]
411
+ },
412
+ {
413
+ "cell_type": "markdown",
414
+ "metadata": {
415
+ "id": "gfjraK3vuHb_"
416
+ },
417
+ "source": [
418
+ "### Input / Output Utils\n",
419
+ "\n",
420
+ "Helper functions for reading / writing to zipfiles and csv"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": 17,
426
+ "metadata": {
427
+ "id": "qg0g5B-fuHb_"
428
+ },
429
+ "outputs": [],
430
+ "source": [
431
+ "import io\n",
432
+ "import csv\n",
433
+ "from pathlib import Path\n",
434
+ "from contextlib import contextmanager\n",
435
+ "import zipfile\n",
436
+ "import filetype\n",
437
+ "from PIL import Image\n",
438
+ "logger = logging.getLogger()\n",
439
+ "\n",
440
+ "def images_from_zipfile(p: Path):\n",
441
+ " if not zipfile.is_zipfile(p):\n",
442
+ " raise ValueError(f'{p} is not a zipfile!')\n",
443
+ "\n",
444
+ " with zipfile.ZipFile(p, 'r') as zipf:\n",
445
+ " def process_entry(info: zipfile.ZipInfo):\n",
446
+ " with zipf.open(info) as f:\n",
447
+ " if not filetype.is_image(f):\n",
448
+ " logger.debug(f'Skipping file - {info.filename} as it is not an image')\n",
449
+ " return\n",
450
+ " # Try loading the file\n",
451
+ " try:\n",
452
+ " with Image.open(f) as im:\n",
453
+ " im.load()\n",
454
+ " return (info.filename, im)\n",
455
+ " except:\n",
456
+ " logger.exception(f'Error reading file {info.filename}')\n",
457
+ "\n",
458
+ " num_files = sum(1 for info in zipf.infolist() if info.is_dir() == False)\n",
459
+ " logger.info(f'Found {num_files} file(s) in the zip')\n",
460
+ " yield from (process_entry(info) for info in zipf.infolist() if info.is_dir() == False)\n",
461
+ "\n",
462
+ "@contextmanager\n",
463
+ "def zipfile_writer(p: Path):\n",
464
+ " with zipfile.ZipFile(p, 'w') as zipf:\n",
465
+ " def write_output(image, image_filename):\n",
466
+ " buf = io.BytesIO()\n",
467
+ " image.save(buf, 'PNG')\n",
468
+ " zipf.writestr(image_filename, buf.getvalue())\n",
469
+ " yield write_output\n",
470
+ "\n",
471
+ "@contextmanager\n",
472
+ "def csvfile_writer(p: Path):\n",
473
+ " with p.open('w', newline='') as csvfile:\n",
474
+ " fieldnames = ['filename', 'count']\n",
475
+ " csv_writer = csv.DictWriter(csvfile, fieldnames = fieldnames)\n",
476
+ " csv_writer.writeheader()\n",
477
+ "\n",
478
+ " yield csv_writer.writerow"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": 15,
484
+ "metadata": {
485
+ "id": "rFXRk-_uuHb_"
486
+ },
487
+ "outputs": [],
488
+ "source": [
489
+ "from tqdm import tqdm\n",
490
+ "import os\n",
491
+ "def process_zipfile(input_zipfile: Path, text: str):\n",
492
+ " if not input_zipfile.exists() or not input_zipfile.is_file() or not os.access(input_zipfile, os.R_OK):\n",
493
+ " logger.error(f'Cannot open / read zipfile: {input_zipfile}. Please check if it exists')\n",
494
+ " return\n",
495
+ "\n",
496
+ " if text == \"\":\n",
497
+ " logger.error('Please provide the object you would like to count')\n",
498
+ " return\n",
499
+ "\n",
500
+ " output_zipfile = input_zipfile.parent / f'{input_zipfile.stem}_countgd.zip'\n",
501
+ " output_csvfile = input_zipfile.parent / f'{input_zipfile.stem}.csv'\n",
502
+ "\n",
503
+ " logger.info(f'Writing outputs to {output_zipfile.name} and {output_csvfile.name} in {input_zipfile.parent} folder')\n",
504
+ " with zipfile_writer(output_zipfile) as add_to_zip, csvfile_writer(output_csvfile) as write_row:\n",
505
+ " for filename, im in tqdm(images_from_zipfile(input_zipfile)):\n",
506
+ " boxes, _ = run(im, text)\n",
507
+ " count, heatmap = get_output(im, boxes)\n",
508
+ " write_row({'filename': filename, 'count': count})\n",
509
+ " add_to_zip(heatmap, filename)"
510
+ ]
511
+ },
512
+ {
513
+ "cell_type": "markdown",
514
+ "metadata": {
515
+ "id": "TmqsSxrsuHb_"
516
+ },
517
+ "source": [
518
+ "### Run\n",
519
+ "\n",
520
+ "Use the form on colab to set the parameters, providing the zipfile with input images and a promt text representing the object you want to count.\n",
521
+ "\n",
522
+ "If you are not running on colab, change the values in the next cell\n",
523
+ "\n",
524
+ "Make sure to run the cell once you change the value."
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "code",
529
+ "execution_count": 8,
530
+ "metadata": {
531
+ "id": "ZaN918EkuHb_"
532
+ },
533
+ "outputs": [],
534
+ "source": [
535
+ "# @title ## Parameters { display-mode: \"form\", run: \"auto\" }\n",
536
+ "# @markdown Set the following options to pass to the CountGD Model\n",
537
+ "\n",
538
+ "# @markdown ---\n",
539
+ "# @markdown ### Enter a file path to a zip:\n",
540
+ "zipfile_path = \"test_images.zip\" # @param {type:\"string\"}\n",
541
+ "# @markdown\n",
542
+ "# @markdown ### Which object would you like to count?\n",
543
+ "prompt = \"strawberry\" # @param {type:\"string\"}\n",
544
+ "# @markdown ---"
545
+ ]
546
+ },
547
+ {
548
+ "cell_type": "code",
549
+ "execution_count": 18,
550
+ "metadata": {
551
+ "colab": {
552
+ "base_uri": "https://localhost:8080/",
553
+ "height": 66,
554
+ "referenced_widgets": [
555
+ "b14c910dd2594285bb4ad4740099e70c",
556
+ "01631442369e43138c2c5c4a9fe38ceb",
557
+ "ff84907ef88a431bab4bd3d1567cc42a"
558
+ ]
559
+ },
560
+ "id": "fd-ShBCsuHb_",
561
+ "outputId": "5b36bb90-ac6e-46fe-a853-ff11d43dd9f6"
562
+ },
563
+ "outputs": [
564
+ {
565
+ "data": {
566
+ "application/vnd.jupyter.widget-view+json": {
567
+ "model_id": "b14c910dd2594285bb4ad4740099e70c",
568
+ "version_major": 2,
569
+ "version_minor": 0
570
+ },
571
+ "text/plain": [
572
+ "Button(description='Run', style=ButtonStyle())"
573
+ ]
574
+ },
575
+ "metadata": {},
576
+ "output_type": "display_data"
577
+ },
578
+ {
579
+ "name": "stderr",
580
+ "output_type": "stream",
581
+ "text": [
582
+ "11it [00:12, 1.14s/it]\n"
583
+ ]
584
+ }
585
+ ],
586
+ "source": [
587
+ "import ipywidgets as widgets\n",
588
+ "from IPython.display import display\n",
589
+ "button = widgets.Button(description=\"Run\")\n",
590
+ "\n",
591
+ "def on_button_clicked(b):\n",
592
+ " # Display the message within the output widget.\n",
593
+ " process_zipfile(Path(zipfile_path), prompt)\n",
594
+ "\n",
595
+ "button.on_click(on_button_clicked)\n",
596
+ "display(button)"
597
+ ]
598
+ }
599
+ ],
600
+ "metadata": {
601
+ "accelerator": "GPU",
602
+ "colab": {
603
+ "collapsed_sections": [
604
+ "gfjraK3vuHb_"
605
+ ],
606
+ "gpuType": "T4",
607
+ "provenance": []
608
+ },
609
+ "kernelspec": {
610
+ "display_name": "Python 3",
611
+ "name": "python3"
612
+ },
613
+ "language_info": {
614
+ "codemirror_mode": {
615
+ "name": "ipython",
616
+ "version": 3
617
+ },
618
+ "file_extension": ".py",
619
+ "mimetype": "text/x-python",
620
+ "name": "python",
621
+ "nbconvert_exporter": "python",
622
+ "pygments_lexer": "ipython3",
623
+ "version": "3.12.7"
624
+ },
625
+ "widgets": {
626
+ "application/vnd.jupyter.widget-state+json": {
627
+ "01631442369e43138c2c5c4a9fe38ceb": {
628
+ "model_module": "@jupyter-widgets/base",
629
+ "model_module_version": "1.2.0",
630
+ "model_name": "LayoutModel",
631
+ "state": {
632
+ "_model_module": "@jupyter-widgets/base",
633
+ "_model_module_version": "1.2.0",
634
+ "_model_name": "LayoutModel",
635
+ "_view_count": null,
636
+ "_view_module": "@jupyter-widgets/base",
637
+ "_view_module_version": "1.2.0",
638
+ "_view_name": "LayoutView",
639
+ "align_content": null,
640
+ "align_items": null,
641
+ "align_self": null,
642
+ "border": null,
643
+ "bottom": null,
644
+ "display": null,
645
+ "flex": null,
646
+ "flex_flow": null,
647
+ "grid_area": null,
648
+ "grid_auto_columns": null,
649
+ "grid_auto_flow": null,
650
+ "grid_auto_rows": null,
651
+ "grid_column": null,
652
+ "grid_gap": null,
653
+ "grid_row": null,
654
+ "grid_template_areas": null,
655
+ "grid_template_columns": null,
656
+ "grid_template_rows": null,
657
+ "height": null,
658
+ "justify_content": null,
659
+ "justify_items": null,
660
+ "left": null,
661
+ "margin": null,
662
+ "max_height": null,
663
+ "max_width": null,
664
+ "min_height": null,
665
+ "min_width": null,
666
+ "object_fit": null,
667
+ "object_position": null,
668
+ "order": null,
669
+ "overflow": null,
670
+ "overflow_x": null,
671
+ "overflow_y": null,
672
+ "padding": null,
673
+ "right": null,
674
+ "top": null,
675
+ "visibility": null,
676
+ "width": null
677
+ }
678
+ },
679
+ "b14c910dd2594285bb4ad4740099e70c": {
680
+ "model_module": "@jupyter-widgets/controls",
681
+ "model_module_version": "1.5.0",
682
+ "model_name": "ButtonModel",
683
+ "state": {
684
+ "_dom_classes": [],
685
+ "_model_module": "@jupyter-widgets/controls",
686
+ "_model_module_version": "1.5.0",
687
+ "_model_name": "ButtonModel",
688
+ "_view_count": null,
689
+ "_view_module": "@jupyter-widgets/controls",
690
+ "_view_module_version": "1.5.0",
691
+ "_view_name": "ButtonView",
692
+ "button_style": "",
693
+ "description": "Run",
694
+ "disabled": false,
695
+ "icon": "",
696
+ "layout": "IPY_MODEL_01631442369e43138c2c5c4a9fe38ceb",
697
+ "style": "IPY_MODEL_ff84907ef88a431bab4bd3d1567cc42a",
698
+ "tooltip": ""
699
+ }
700
+ },
701
+ "ff84907ef88a431bab4bd3d1567cc42a": {
702
+ "model_module": "@jupyter-widgets/controls",
703
+ "model_module_version": "1.5.0",
704
+ "model_name": "ButtonStyleModel",
705
+ "state": {
706
+ "_model_module": "@jupyter-widgets/controls",
707
+ "_model_module_version": "1.5.0",
708
+ "_model_name": "ButtonStyleModel",
709
+ "_view_count": null,
710
+ "_view_module": "@jupyter-widgets/base",
711
+ "_view_module_version": "1.2.0",
712
+ "_view_name": "StyleView",
713
+ "button_color": null,
714
+ "font_weight": ""
715
+ }
716
+ }
717
+ }
718
+ }
719
+ },
720
+ "nbformat": 4,
721
+ "nbformat_minor": 0
722
+ }