akshayka commited on
Commit
f07bc5e
·
verified ·
1 Parent(s): 59149bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +282 -138
app.py CHANGED
@@ -1,180 +1,324 @@
1
  # /// script
2
- # requires-python = ">=3.10"
3
  # dependencies = [
4
  # "marimo",
5
- # "numba==0.60.0",
6
- # "numpy==2.0.2",
7
- # "scikit-image==0.24.0",
8
  # ]
9
  # ///
10
 
11
  import marimo
12
 
13
- __generated_with = "0.9.6"
14
- app = marimo.App(width="medium")
 
 
 
 
 
 
15
 
16
 
17
  @app.cell(hide_code=True)
18
- def __(mo):
19
  mo.md(
20
- """
21
- # Seam Carving
22
-
23
- _Example adapted from work by [Vincent Warmerdam](https://x.com/fishnets88)_.
24
 
25
- ## The seam carving algorithm
26
- This marimo demonstration is partially an homage to [a great video by Grant
27
- Sanderson](https://www.youtube.com/watch?v=rpB6zQNsbQU) of 3Blue1Brown, which demonstrates
28
- the seam carving algorithm in [Pluto.jl](https://plutojl.org/):
29
 
30
- <iframe width="560" height="315" src="https://www.youtube.com/embed/rpB6zQNsbQU?si=oiZclGIj2atJR47m" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen></iframe>
 
 
 
31
 
32
- As Grant explains, the seam carving algorithm preserves the shapes of the main content in the image, while killing the "dead space": the image is resized, but the clocks and other content are not resized or deformed.
33
 
34
- This notebook is a Python version of the seam carving algorithm, but it is also a
35
- demonstration of marimo's [caching
36
- feature](https://docs.marimo.io/guides/best_practices/performance.html#cache-computations-with-mo-cache),
37
- which is helpful because the algorithm is compute intensive even when you
38
- use [Numba](https://numba.pydata.org/).
39
 
40
- Try it out by playing with the slider!
41
  """
42
  )
43
  return
44
 
45
 
46
  @app.cell(hide_code=True)
47
- def __():
48
- input_image = "https://upload.wikimedia.org/wikipedia/en/d/dd/The_Persistence_of_Memory.jpg"
 
 
49
 
50
- return input_image,
 
 
 
51
 
52
 
53
  @app.cell(hide_code=True)
54
- def __(mo):
55
- mo.md("""## Try it!""")
56
  return
57
 
58
 
59
- @app.cell
60
- def __():
61
- import marimo as mo
 
62
 
63
- slider = mo.ui.slider(
64
- 0.7,
65
- 1.0,
66
- step=0.05,
67
- value=1.0,
68
- label="Amount of resizing to perform:",
69
- show_value=True,
70
- )
71
- slider
72
- return mo, slider
73
 
74
 
75
  @app.cell
76
- def __(efficient_seam_carve, input_image, mo, slider):
77
- scale_factor = slider.value
78
- result = efficient_seam_carve(input_image, scale_factor)
 
 
 
 
79
 
80
- mo.hstack([mo.image(input_image), mo.image(result)], justify="start")
81
- return result, scale_factor
 
 
 
 
 
 
 
82
 
83
 
84
  @app.cell
85
- def __(mo):
86
  import numpy as np
87
- from numba import jit
88
- from skimage import io, filters, transform
89
- import time
90
-
91
-
92
- def rgb2gray(rgb):
93
- return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])
94
-
95
-
96
- def compute_energy_map(gray):
97
- return np.abs(filters.sobel_h(gray)) + np.abs(filters.sobel_v(gray))
98
-
99
-
100
- @jit(nopython=True)
101
- def find_seam(energy_map):
102
- height, width = energy_map.shape
103
- dp = energy_map.copy()
104
- backtrack = np.zeros((height, width), dtype=np.int32)
105
-
106
- for i in range(1, height):
107
- for j in range(width):
108
- if j == 0:
109
- idx = np.argmin(dp[i - 1, j : j + 2])
110
- backtrack[i, j] = idx + j
111
- min_energy = dp[i - 1, idx + j]
112
- elif j == width - 1:
113
- idx = np.argmin(dp[i - 1, j - 1 : j + 1])
114
- backtrack[i, j] = idx + j - 1
115
- min_energy = dp[i - 1, idx + j - 1]
116
- else:
117
- idx = np.argmin(dp[i - 1, j - 1 : j + 2])
118
- backtrack[i, j] = idx + j - 1
119
- min_energy = dp[i - 1, idx + j - 1]
120
-
121
- dp[i, j] += min_energy
122
-
123
- return backtrack
124
-
125
-
126
- @jit(nopython=True)
127
- def remove_seam(image, backtrack):
128
- height, width, _ = image.shape
129
- output = np.zeros((height, width - 1, 3), dtype=np.uint8)
130
- j = np.argmin(backtrack[-1])
131
-
132
- for i in range(height - 1, -1, -1):
133
- for k in range(3):
134
- output[i, :, k] = np.delete(image[i, :, k], j)
135
- j = backtrack[i, j]
136
-
137
- return output
138
-
139
-
140
- def seam_carving(image, new_width):
141
- height, width, _ = image.shape
142
-
143
- while width > new_width:
144
- gray = rgb2gray(image)
145
- energy_map = compute_energy_map(gray)
146
- backtrack = find_seam(energy_map)
147
- image = remove_seam(image, backtrack)
148
- width -= 1
149
-
150
- return image
151
-
152
- @mo.cache
153
- def efficient_seam_carve(image_path, scale_factor):
154
- img = io.imread(image_path)
155
- new_width = int(img.shape[1] * scale_factor)
156
-
157
- start_time = time.time()
158
- carved_img = seam_carving(img, new_width)
159
- end_time = time.time()
160
-
161
- print(f"Seam carving completed in {end_time - start_time:.2f} seconds")
162
-
163
- return carved_img
164
- return (
165
- compute_energy_map,
166
- efficient_seam_carve,
167
- filters,
168
- find_seam,
169
- io,
170
- jit,
171
- np,
172
- remove_seam,
173
- rgb2gray,
174
- seam_carving,
175
- time,
176
- transform,
177
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
 
180
  if __name__ == "__main__":
 
1
  # /// script
2
+ # requires-python = ">=3.13"
3
  # dependencies = [
4
  # "marimo",
5
+ # "matplotlib==3.10.1",
6
+ # "numpy==2.2.3",
 
7
  # ]
8
  # ///
9
 
10
  import marimo
11
 
12
+ __generated_with = "0.11.20"
13
+ app = marimo.App()
14
+
15
+
16
+ @app.cell
17
+ def _():
18
+ import marimo as mo
19
+ return (mo,)
20
 
21
 
22
  @app.cell(hide_code=True)
23
+ def _(mo):
24
  mo.md(
25
+ r"""
26
+ # Finding $\pi$ in colliding blocks
 
 
27
 
28
+ One of the remarkable things about mathematical constants like $\pi$ is how frequently they arise in nature, in the most surprising of places.
 
 
 
29
 
30
+ Inspired by 3Blue1Brown, this marimo notebook shows how the number of collisions incurred in a particular system involving two blocks converges to the digits in $\pi$.
31
+ """
32
+ )
33
+ return
34
 
 
35
 
36
+ @app.cell(hide_code=True)
37
+ def _(mo):
38
+ mo.md(
39
+ r"""
40
+ ## The 3Blue1Brown video
41
 
42
+ If you haven't seen it, definitely check out the video that inspired this notebook!
43
  """
44
  )
45
  return
46
 
47
 
48
  @app.cell(hide_code=True)
49
+ def _(mo):
50
+ mo.Html('<iframe width="700" height="400" src="https://www.youtube.com/embed/6dTyOl1fmDo?si=xl9v6Y8x2e3r3A9I" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen></iframe>')
51
+ return
52
+
53
 
54
+ @app.cell(hide_code=True)
55
+ def _(mo):
56
+ slider = mo.ui.slider(start=0, stop=4, value=0, show_value=True)
57
+ return (slider,)
58
 
59
 
60
  @app.cell(hide_code=True)
61
+ def _(mo, slider):
62
+ mo.md(f"Use this slider to control the weight of the heavier block: {slider}")
63
  return
64
 
65
 
66
+ @app.cell(hide_code=True)
67
+ def _(mo, slider):
68
+ mo.md(rf"The heavier block weighs **$100^{{ {slider.value} }}$** kg.")
69
+ return
70
 
71
+
72
+ @app.cell(hide_code=True)
73
+ def _(mo):
74
+ run_button = mo.ui.run_button(label="Run simulation!")
75
+ run_button.right()
76
+ return (run_button,)
 
 
 
 
77
 
78
 
79
  @app.cell
80
+ def _(run_button, simulate_collisions, slider):
81
+ if run_button.value:
82
+ mass_ratio = 100**slider.value
83
+ _, ani, collisions = simulate_collisions(
84
+ mass_ratio, total_time=15, dt=0.001
85
+ )
86
+ return ani, collisions, mass_ratio
87
 
88
+
89
+ @app.cell
90
+ def _(ani, mo, run_button):
91
+ video = None
92
+ if run_button.value:
93
+ with mo.status.spinner(title="Rendering collision video ..."):
94
+ video = mo.Html(ani.to_html5_video())
95
+ video
96
+ return (video,)
97
 
98
 
99
  @app.cell
100
+ def _():
101
  import numpy as np
102
+ import matplotlib.pyplot as plt
103
+ import matplotlib.animation as animation
104
+ from matplotlib.patches import Rectangle
105
+ return Rectangle, animation, np, plt
106
+
107
+
108
+ @app.cell
109
+ def _():
110
+ class Block:
111
+ def __init__(self, mass, velocity, position, size=1.0):
112
+ self.mass = mass
113
+ self.velocity = velocity
114
+ self.position = position
115
+ self.size = size
116
+
117
+ def update(self, dt):
118
+ self.position += self.velocity * dt
119
+
120
+ def collide(self, other):
121
+ # Calculate velocities after elastic collision
122
+ m1, m2 = self.mass, other.mass
123
+ v1, v2 = self.velocity, other.velocity
124
+
125
+ new_v1 = (m1 - m2) / (m1 + m2) * v1 + (2 * m2) / (m1 + m2) * v2
126
+ new_v2 = (2 * m1) / (m1 + m2) * v1 + (m2 - m1) / (m1 + m2) * v2
127
+
128
+ self.velocity = new_v1
129
+ other.velocity = new_v2
130
+
131
+ return 1 # Return 1 collision
132
+ return (Block,)
133
+
134
+
135
+ @app.cell
136
+ def check_collisions():
137
+ def check_collisions(small_block, big_block, wall_pos=0):
138
+ collisions = 0
139
+
140
+ # Check for collision between blocks
141
+ if small_block.position + small_block.size > big_block.position:
142
+ small_block.position = big_block.position - small_block.size
143
+ collisions += small_block.collide(big_block)
144
+
145
+ # Check for collision with the wall
146
+ if small_block.position < wall_pos:
147
+ small_block.position = wall_pos
148
+ small_block.velocity *= -1
149
+ collisions += 1
150
+
151
+ return collisions
152
+ return (check_collisions,)
153
+
154
+
155
+ @app.cell
156
+ def _(Block, check_collisions, create_animation):
157
+ def simulate_collisions(mass_ratio, total_time=15, dt=0.001, animate=True):
158
+ # Initialize blocks
159
+ small_block = Block(mass=1, velocity=0, position=2)
160
+ big_block = Block(mass=mass_ratio, velocity=-0.5, position=4)
161
+
162
+ # Simulation variables
163
+ time = 0
164
+ collision_count = 0
165
+
166
+ # For animation
167
+ times = []
168
+ small_positions = []
169
+ big_positions = []
170
+ collision_counts = []
171
+
172
+ # Run simulation
173
+ while time < total_time:
174
+ # Update positions
175
+ small_block.update(dt)
176
+ big_block.update(dt)
177
+
178
+ # Check for and handle collisions
179
+ new_collisions = check_collisions(small_block, big_block)
180
+ collision_count += new_collisions
181
+
182
+ # Store data for animation
183
+ times.append(time)
184
+ small_positions.append(small_block.position)
185
+ big_positions.append(big_block.position)
186
+ collision_counts.append(collision_count)
187
+
188
+ time += dt
189
+
190
+
191
+ print(f"Mass ratio: {mass_ratio}, Total collisions: {collision_count}")
192
+
193
+ if animate:
194
+ axis, ani = create_animation(
195
+ times, small_positions, big_positions, collision_counts, mass_ratio
196
+ )
197
+ else:
198
+ axis, ani = None
199
+
200
+ return axis, ani, collision_count
201
+ return (simulate_collisions,)
202
+
203
+
204
+ @app.cell
205
+ def _(Rectangle, animation, plt):
206
+ def create_animation(
207
+ times, small_positions, big_positions, collision_counts, mass_ratio
208
+ ):
209
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
210
+
211
+ # Setup for blocks visualization
212
+ ax1.set_xlim(-1, 10)
213
+ ax1.set_ylim(-1, 2)
214
+ ax1.set_xlabel("Position")
215
+ ax1.set_title(f"Block Collisions (Mass Ratio = {mass_ratio})")
216
+ wall = plt.Line2D([0, 0], [-1, 2], color="black", linewidth=3)
217
+ ax1.add_line(wall)
218
+
219
+ small_block = Rectangle((small_positions[0], 0), 1, 1, color="blue")
220
+ big_block = Rectangle((big_positions[0], 0), 1, 1, color="red")
221
+ ax1.add_patch(small_block)
222
+ ax1.add_patch(big_block)
223
+
224
+ # Add weight labels for each block
225
+ small_label = ax1.text(
226
+ small_positions[0] + 0.5,
227
+ 1.2,
228
+ f"{1}kg",
229
+ ha="center",
230
+ va="center",
231
+ color="blue",
232
+ fontweight="bold",
233
+ )
234
+ big_label = ax1.text(
235
+ big_positions[0] + 0.5,
236
+ 1.2,
237
+ f"{mass_ratio}kg",
238
+ ha="center",
239
+ va="center",
240
+ color="red",
241
+ fontweight="bold",
242
+ )
243
+
244
+ # Setup for collision count
245
+ ax2.set_xlim(0, times[-1])
246
+ # ax2.set_ylim(0, collision_counts[-1] * 1.1)
247
+ ax2.set_ylim(0, collision_counts[-1] * 1.1)
248
+ ax2.set_xlabel("Time")
249
+ ax2.set_ylabel("# Collisions:")
250
+ ax2.set_yscale("symlog")
251
+ (collision_line,) = ax2.plot([], [], "g-")
252
+
253
+ # Add text for collision count
254
+ collision_text = ax2.text(
255
+ 0.02, 0.9, "", transform=ax2.transAxes, fontsize="x-large"
256
+ )
257
+
258
+ def init():
259
+ small_block.set_xy((small_positions[0], 0))
260
+ big_block.set_xy((big_positions[0], 0))
261
+ small_label.set_position((small_positions[0] + 0.5, 1.2))
262
+ big_label.set_position((big_positions[0] + 0.5, 1.2))
263
+ collision_line.set_data([], [])
264
+ collision_text.set_text("")
265
+ return small_block, big_block, collision_line, collision_text
266
+
267
+ frame_step = 300
268
+
269
+ def animate(i):
270
+ # Speed up animation but ensure we reach the final frame
271
+ frame_index = min(i * frame_step, len(times) - 1)
272
+
273
+ small_block.set_xy((small_positions[frame_index], 0))
274
+ big_block.set_xy((big_positions[frame_index], 0))
275
+
276
+ # Update the weight labels to follow the blocks
277
+ small_label.set_position((small_positions[frame_index] + 0.5, 1.2))
278
+ big_label.set_position((big_positions[frame_index] + 0.5, 1.2))
279
+
280
+ # Show data up to the current frame
281
+ collision_line.set_data(
282
+ times[: frame_index + 1], collision_counts[: frame_index + 1]
283
+ )
284
+
285
+ # For the last frame, show the final collision count
286
+ if frame_index >= len(times) - 1:
287
+ collision_text.set_text(
288
+ f"# Collisions: {collision_counts[-1]}"
289
+ )
290
+ else:
291
+ collision_text.set_text(
292
+ f"# Collisions: {collision_counts[frame_index]}"
293
+ )
294
+
295
+ return (
296
+ small_block,
297
+ big_block,
298
+ small_label,
299
+ big_label,
300
+ collision_line,
301
+ collision_text,
302
+ )
303
+
304
+ plt.tight_layout()
305
+
306
+ frames = max(1, len(times) // frame_step) # Ensure at least 1 frame
307
+ ani = animation.FuncAnimation(
308
+ fig,
309
+ animate,
310
+ frames=frames + 1, # +1 to ensure we reach the end
311
+ init_func=init,
312
+ blit=True,
313
+ interval=30,
314
+ )
315
+
316
+ plt.tight_layout()
317
+ return plt.gca(), ani
318
+
319
+ # Uncomment to save animation
320
+ # ani.save('pi_collisions.mp4', writer='ffmpeg', fps=30)
321
+ return (create_animation,)
322
 
323
 
324
  if __name__ == "__main__":