Haleshot committed
Commit 7e35597 · unverified · 1 Parent(s): e338d9a

Add Probability Mass Functions (PMF) notebook


This notebook introduces Probability Mass Functions (PMFs) with explanations and interactive visualizations.

probability/10_probability_mass_function.py ADDED
@@ -0,0 +1,711 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "marimo",
#     "matplotlib==3.10.0",
#     "numpy==2.2.3",
#     "scipy==1.15.2",
# ]
# ///

import marimo

__generated_with = "0.11.17"
app = marimo.App(width="medium", app_title="Probability Mass Functions")


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        # Probability Mass Functions

        _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/pmf/), by Stanford professor Chris Piech._

        For a random variable, the most important thing to know is: how likely is each outcome? For a discrete random variable, this information is called the "**Probability Mass Function**". The probability mass function (PMF) provides the "mass" (i.e. amount) of "probability" for each possible assignment of the random variable.

        Formally, the Probability Mass Function is a mapping between the values that the random variable could take on and the probability of the random variable taking on said value. In mathematics, we call these associations functions. There are many different ways of representing functions: you can write an equation, you can make a graph, you can even store many samples in a list.
        """
    )
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Properties of a PMF

        For a function $p_X(x)$ to be a valid PMF, it must satisfy:

        1. **Non-negativity**: $p_X(x) \geq 0$ for all $x$
        2. **Unit total probability**: $\sum_x p_X(x) = 1$

        ### Probabilities Must Sum to 1

        For a variable (call it $X$) to be a proper random variable, summing $P(X=x)$ over all possible values $x$ that $X$ can take on must give exactly 1:

        $$\sum_x P(X=x) = 1$$

        This is because a random variable taking on a value is an event (for example $X=3$). Those events are mutually exclusive because a random variable takes on exactly one value. Together they cover the entire sample space, since $X$ must take on some value.
        """
    )
    return


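@app.cell
def _():
    # Illustrative addition (not part of the original commit): a minimal
    # numeric check of the two PMF properties above, using a fair six-sided
    # die as the example. `fair_die_pmf` is a name introduced here.
    fair_die_pmf = {x: 1 / 6 for x in range(1, 7)}

    check_nonnegative = all(p >= 0 for p in fair_die_pmf.values())
    check_sums_to_one = abs(sum(fair_die_pmf.values()) - 1) < 1e-9

    print(f"Non-negativity holds: {check_nonnegative}")
    print(f"Probabilities sum to 1: {check_sums_to_one}")
    return check_nonnegative, check_sums_to_one, fair_die_pmf

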
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## PMFs as Graphs

        Let's start by looking at PMFs as graphs where the $x$-axis is the values that the random variable could take on and the $y$-axis is the probability of the random variable taking on said value.

        In the following example, we show two PMFs:

        - On the left: PMF for the random variable $X$ = the value of a single six-sided die roll
        - On the right: PMF for the random variable $Y$ = value of the sum of two dice rolls
        """
    )
    return


@app.cell(hide_code=True)
def _(np, plt):
    # Single die PMF
    single_die_values = np.arange(1, 7)
    single_die_probs = np.ones(6) / 6

    # Two dice sum PMF
    two_dice_values = np.arange(2, 13)
    two_dice_probs = []

    for dice_sum in two_dice_values:
        if dice_sum <= 7:
            dice_prob = (dice_sum-1) / 36
        else:
            dice_prob = (13-dice_sum) / 36
        two_dice_probs.append(dice_prob)

    # Create side-by-side plots
    dice_fig, (dice_ax1, dice_ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Single die plot
    dice_ax1.bar(single_die_values, single_die_probs, width=0.4)
    dice_ax1.set_xticks(single_die_values)
    dice_ax1.set_xlabel('Value of die roll (x)')
    dice_ax1.set_ylabel('Probability: P(X = x)')
    dice_ax1.set_title('PMF of a Single Die Roll')
    dice_ax1.grid(alpha=0.3)

    # Two dice sum plot
    dice_ax2.bar(two_dice_values, two_dice_probs, width=0.4)
    dice_ax2.set_xticks(two_dice_values)
    dice_ax2.set_xlabel('Sum of two dice (y)')
    dice_ax2.set_ylabel('Probability: P(Y = y)')
    dice_ax2.set_title('PMF of Sum of Two Dice')
    dice_ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.gca()
    return (
        dice_ax1,
        dice_ax2,
        dice_fig,
        dice_prob,
        dice_sum,
        single_die_probs,
        single_die_values,
        two_dice_probs,
        two_dice_values,
    )


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        The information provided in these graphs shows the likelihood of a random variable taking on different values.

        In the graph on the right, the value "6" on the $x$-axis is associated with the probability $\frac{5}{36}$ on the $y$-axis. This $x$-value refers to the event "the sum of two dice is 6", or $Y = 6$. The $y$-value tells us that the probability of that event is $\frac{5}{36}$. In full: $P(Y = 6) = \frac{5}{36}$.

        The value "2" is associated with "$\frac{1}{36}$", which tells us that $P(Y = 2) = \frac{1}{36}$: the probability that two dice sum to 2 is $\frac{1}{36}$. There is no bar at "1" because the sum of two dice cannot be 1.
        """
    )
    return


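@app.cell
def _():
    # Illustrative addition (not part of the original commit): where does
    # P(Y = 6) = 5/36 come from? Enumerate the 36 equally likely outcomes of
    # two dice and count the pairs that sum to 6.
    pairs_summing_to_6 = [
        (d1, d2) for d1 in range(1, 7) for d2 in range(1, 7) if d1 + d2 == 6
    ]
    print(f"Pairs summing to 6: {pairs_summing_to_6}")
    print(f"P(Y = 6) = {len(pairs_summing_to_6)}/36")
    return (pairs_summing_to_6,)

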
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## PMFs as Equations

        Here is the exact same information in equation form:

        For a single die roll $X$:
        $$P(X=x) = \frac{1}{6} \quad \text{ if } 1 \leq x \leq 6$$

        For the sum of two dice $Y$:
        $$P(Y=y) = \begin{cases}
        \frac{(y-1)}{36} & \text{ if } 2 \leq y \leq 7\\
        \frac{(13-y)}{36} & \text{ if } 8 \leq y \leq 12
        \end{cases}$$

        Let's implement the PMF for $Y$, the sum of two dice, in Python code:
        """
    )
    return


@app.cell
def _():
    def pmf_sum_two_dice(y_val):
        """Returns the probability that the sum of two dice is y"""
        if y_val < 2 or y_val > 12:
            return 0
        if y_val <= 7:
            return (y_val-1) / 36
        else:
            return (13-y_val) / 36

    # Test the function for a few values
    test_values = [1, 2, 7, 12, 13]
    for test_y in test_values:
        print(f"P(Y = {test_y}) = {pmf_sum_two_dice(test_y)}")
    return pmf_sum_two_dice, test_values, test_y


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""Now, let's verify that our PMF satisfies the property that the sum of all probabilities equals 1:""")
    return


@app.cell
def _(pmf_sum_two_dice):
    # Verify that probabilities sum to 1
    verify_total_prob = sum(pmf_sum_two_dice(y_val) for y_val in range(2, 13))
    # Round to 10 decimal places to handle floating-point precision
    verify_total_prob_rounded = round(verify_total_prob, 10)
    print(f"Sum of all probabilities: {verify_total_prob_rounded}")
    return verify_total_prob, verify_total_prob_rounded


@app.cell(hide_code=True)
def _(plt, pmf_sum_two_dice):
    # Create a visual verification
    verify_y_values = list(range(2, 13))
    verify_probabilities = [pmf_sum_two_dice(y_val) for y_val in verify_y_values]

    plt.figure(figsize=(10, 4))
    plt.bar(verify_y_values, verify_probabilities, width=0.4)
    plt.xticks(verify_y_values)
    plt.xlabel('Sum of two dice (y)')
    plt.ylabel('Probability: P(Y = y)')
    plt.title('PMF of Sum of Two Dice (Total Probability = 1)')
    plt.grid(alpha=0.3)

    # Add probability values on top of bars
    for verify_i, verify_prob in enumerate(verify_probabilities):
        plt.text(verify_y_values[verify_i], verify_prob + 0.001, f'{verify_prob:.3f}', ha='center')

    plt.gca()  # Return the current axes to ensure proper display
    return verify_i, verify_prob, verify_probabilities, verify_y_values


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Data to Histograms to Probability Mass Functions

        One surprising way to store a likelihood function (recall that a PMF is the name of the likelihood function for discrete random variables) is simply a list of data. Let's simulate summing two dice many times to create an empirical PMF:
        """
    )
    return


@app.cell
def _(np):
    # Simulate rolling two dice many times
    sim_num_trials = 10000
    np.random.seed(42)  # For reproducibility

    # Generate random dice rolls
    sim_die1 = np.random.randint(1, 7, size=sim_num_trials)
    sim_die2 = np.random.randint(1, 7, size=sim_num_trials)

    # Calculate the sum
    sim_dice_sums = sim_die1 + sim_die2

    # Display a small sample of the data
    print(f"First 20 dice sums: {sim_dice_sums[:20]}")
    print(f"Total number of trials: {sim_num_trials}")
    return sim_dice_sums, sim_die1, sim_die2, sim_num_trials


@app.cell(hide_code=True)
def _(collections, np, plt, sim_dice_sums):
    # Count the frequency of each sum
    sim_counter = collections.Counter(sim_dice_sums)

    # Sort the values
    sim_sorted_values = sorted(sim_counter.keys())

    # Calculate the empirical PMF
    sim_empirical_pmf = [sim_counter[x] / len(sim_dice_sums) for x in sim_sorted_values]

    # Calculate the theoretical PMF
    sim_theoretical_values = np.arange(2, 13)
    sim_theoretical_pmf = []
    for sim_y in sim_theoretical_values:
        if sim_y <= 7:
            sim_prob = (sim_y-1) / 36
        else:
            sim_prob = (13-sim_y) / 36
        sim_theoretical_pmf.append(sim_prob)

    # Create a comparison plot
    sim_fig, (sim_ax1, sim_ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Empirical PMF (normalized histogram)
    sim_ax1.bar(sim_sorted_values, sim_empirical_pmf, width=0.4)
    sim_ax1.set_xticks(sim_sorted_values)
    sim_ax1.set_xlabel('Sum of two dice')
    sim_ax1.set_ylabel('Empirical Probability')
    sim_ax1.set_title(f'Empirical PMF from {len(sim_dice_sums)} Trials')
    sim_ax1.grid(alpha=0.3)

    # Theoretical PMF
    sim_ax2.bar(sim_theoretical_values, sim_theoretical_pmf, width=0.4)
    sim_ax2.set_xticks(sim_theoretical_values)
    sim_ax2.set_xlabel('Sum of two dice')
    sim_ax2.set_ylabel('Theoretical Probability')
    sim_ax2.set_title('Theoretical PMF')
    sim_ax2.grid(alpha=0.3)

    plt.tight_layout()

    # Let's also look at the raw counts (histogram)
    plt.figure(figsize=(10, 4))
    sim_counts = [sim_counter[x] for x in sim_sorted_values]
    plt.bar(sim_sorted_values, sim_counts, width=0.4)
    plt.xticks(sim_sorted_values)
    plt.xlabel('Sum of two dice')
    plt.ylabel('Frequency')
    plt.title('Histogram of Dice Sum Frequencies')
    plt.grid(alpha=0.3)

    # Add count values on top of bars
    for sim_i, sim_count in enumerate(sim_counts):
        plt.text(sim_sorted_values[sim_i], sim_count + 19, str(sim_count), ha='center')

    plt.gca()  # Return the current axes to ensure proper display
    return (
        sim_ax1,
        sim_ax2,
        sim_count,
        sim_counter,
        sim_counts,
        sim_empirical_pmf,
        sim_fig,
        sim_i,
        sim_prob,
        sim_sorted_values,
        sim_theoretical_pmf,
        sim_theoretical_values,
        sim_y,
    )


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        A normalized histogram (where each value is divided by the length of your data list) is an approximation of the PMF. For a dataset of discrete numbers, a histogram shows the count of each value. By the definition of probability, if you divide this count by the number of experiments run, you arrive at an approximation of the probability of the event $P(Y=y)$.

        Let's look at a specific example. If we want to approximate $P(Y=3)$ (the probability that the sum of two dice is 3), we can count the number of times that "3" occurs in our data and divide by the total number of trials:
        """
    )
    return


@app.cell
def _(sim_counter, sim_dice_sums):
    # Calculate P(Y=3) empirically
    sim_count_of_3 = sim_counter[3]
    sim_empirical_prob = sim_count_of_3 / len(sim_dice_sums)

    # Calculate P(Y=3) theoretically
    sim_theoretical_prob = 2/36  # There are 2 ways to get a sum of 3 out of 36 possible outcomes

    print(f"Count of sum=3: {sim_count_of_3}")
    print(f"Empirical P(Y=3): {sim_count_of_3}/{len(sim_dice_sums)} = {sim_empirical_prob:.4f}")
    print(f"Theoretical P(Y=3): 2/36 = {sim_theoretical_prob:.4f}")
    print(f"Difference: {abs(sim_empirical_prob - sim_theoretical_prob):.4f}")
    return sim_count_of_3, sim_empirical_prob, sim_theoretical_prob


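@app.cell
def _(np):
    # Illustrative addition (not part of the original commit): a sketch of
    # how the empirical estimate of P(Y=3) behaves as the number of trials
    # grows; the gap to the theoretical 2/36 tends to shrink.
    lln_rng = np.random.default_rng(42)
    for lln_n in [100, 1_000, 10_000, 100_000]:
        lln_sums = lln_rng.integers(1, 7, size=lln_n) + lln_rng.integers(1, 7, size=lln_n)
        lln_estimate = np.mean(lln_sums == 3)
        print(f"n = {lln_n:>6}: P(Y=3) ≈ {lln_estimate:.4f} (theory: {2/36:.4f})")
    return lln_estimate, lln_n, lln_rng, lln_sums

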
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        As we can see, with a large number of trials, the empirical PMF becomes a very good approximation of the theoretical PMF. This is an example of the [Law of Large Numbers](https://en.wikipedia.org/wiki/Law_of_large_numbers) in action.

        ## Interactive Example: Exploring PMFs

        Let's create an interactive tool to explore different PMFs:
        """
    )
    return


@app.cell
def _(dist_param1, dist_param2, dist_selection, mo):
    mo.hstack([dist_selection, dist_param1, dist_param2], justify="space-around")
    return


@app.cell(hide_code=True)
def _(mo):
    dist_selection = mo.ui.dropdown(
        options=[
            "bernoulli",
            "binomial",
            "geometric",
            "poisson"
        ],
        value="bernoulli",
        label="Select a distribution"
    )

    # Parameters for different distributions
    dist_param1 = mo.ui.slider(
        start=0.05,
        stop=0.95,
        step=0.05,
        value=0.5,
        label="p (success probability)"
    )

    dist_param2 = mo.ui.slider(
        start=1,
        stop=20,
        step=1,
        value=10,
        label="n (trials) or λ (rate)"
    )
    return dist_param1, dist_param2, dist_selection


@app.cell(hide_code=True)
def _(dist_param1, dist_param2, dist_selection, np, plt, stats):
    # Defaults so that dist_n, dist_p, and dist_lam are always bound
    # regardless of the selected distribution (all three appear in the
    # return tuple below)
    dist_n = None
    dist_p = None
    dist_lam = None

    # Set up the plot based on the selected distribution
    if dist_selection.value == "bernoulli":
        # Bernoulli distribution
        dist_p = dist_param1.value
        dist_x_values = np.array([0, 1])
        dist_pmf_values = [1-dist_p, dist_p]
        dist_title = f"Bernoulli PMF (p = {dist_p:.2f})"
        dist_x_label = "Outcome (0 = Failure, 1 = Success)"
        dist_max_x = 1

    elif dist_selection.value == "binomial":
        # Binomial distribution
        dist_n = int(dist_param2.value)
        dist_p = dist_param1.value
        dist_x_values = np.arange(0, dist_n+1)
        dist_pmf_values = stats.binom.pmf(dist_x_values, dist_n, dist_p)
        dist_title = f"Binomial PMF (n = {dist_n}, p = {dist_p:.2f})"
        dist_x_label = "Number of Successes"
        dist_max_x = dist_n

    elif dist_selection.value == "geometric":
        # Geometric distribution
        dist_p = dist_param1.value
        dist_max_x = min(int(5/dist_p), 50)  # Limit the range for visualization
        dist_x_values = np.arange(1, dist_max_x+1)
        dist_pmf_values = stats.geom.pmf(dist_x_values, dist_p)
        dist_title = f"Geometric PMF (p = {dist_p:.2f})"
        dist_x_label = "Number of Trials Until First Success"

    else:  # Poisson
        # Poisson distribution
        dist_lam = dist_param2.value
        dist_max_x = int(dist_lam*3) + 1  # Reasonable range for visualization
        dist_x_values = np.arange(0, dist_max_x)
        dist_pmf_values = stats.poisson.pmf(dist_x_values, dist_lam)
        dist_title = f"Poisson PMF (λ = {dist_lam})"
        dist_x_label = "Number of Events"

    # Create the plot
    plt.figure(figsize=(10, 5))

    # For discrete distributions, use stem plot for clarity
    dist_markerline, dist_stemlines, dist_baseline = plt.stem(
        dist_x_values, dist_pmf_values, markerfmt='o', basefmt=' '
    )
    plt.setp(dist_markerline, markersize=6)
    plt.setp(dist_stemlines, linewidth=1.5)

    # Add a bar plot for better visibility
    plt.bar(dist_x_values, dist_pmf_values, alpha=0.3, width=0.4)

    plt.xlabel(dist_x_label)
    plt.ylabel("Probability: P(X = x)")
    plt.title(dist_title)
    plt.grid(alpha=0.3)

    # Calculate and display expected value and variance
    if dist_selection.value == "bernoulli":
        dist_mean = dist_p
        dist_variance = dist_p * (1-dist_p)
    elif dist_selection.value == "binomial":
        dist_mean = dist_n * dist_p
        dist_variance = dist_n * dist_p * (1-dist_p)
    elif dist_selection.value == "geometric":
        dist_mean = 1/dist_p
        dist_variance = (1-dist_p)/(dist_p**2)
    else:  # Poisson
        dist_mean = dist_lam
        dist_variance = dist_lam

    dist_std_dev = np.sqrt(dist_variance)

    # Add text with distribution properties
    dist_props_text = (
        f"Mean: {dist_mean:.3f}\n"
        f"Variance: {dist_variance:.3f}\n"
        f"Std Dev: {dist_std_dev:.3f}\n"
        f"Sum of probabilities: {sum(dist_pmf_values):.6f}"
    )

    plt.text(0.95, 0.95, dist_props_text,
             transform=plt.gca().transAxes,
             verticalalignment='top',
             horizontalalignment='right',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.gca()  # Return the current axes to ensure proper display
    return (
        dist_baseline,
        dist_lam,
        dist_markerline,
        dist_max_x,
        dist_mean,
        dist_n,
        dist_p,
        dist_pmf_values,
        dist_props_text,
        dist_std_dev,
        dist_stemlines,
        dist_title,
        dist_variance,
        dist_x_label,
        dist_x_values,
    )


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Expected Value from a PMF

        The expected value (or mean) of a discrete random variable is calculated using its PMF:

        $$E[X] = \sum_x x \cdot p_X(x)$$

        This represents the long-run average value of the random variable.
        """
    )
    return


@app.cell
def _(dist_pmf_values, dist_x_values):
    def calc_expected_value(x_values, pmf_values):
        """Calculate the expected value of a discrete random variable."""
        return sum(x * p for x, p in zip(x_values, pmf_values))

    # Calculate expected value for the current distribution
    ev_dist_mean = calc_expected_value(dist_x_values, dist_pmf_values)

    print(f"Expected value: {ev_dist_mean:.4f}")
    return calc_expected_value, ev_dist_mean


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Variance from a PMF

        The variance measures the spread or dispersion of a random variable around its mean:

        $$\text{Var}(X) = E[(X - E[X])^2] = \sum_x (x - E[X])^2 \cdot p_X(x)$$

        An alternative formula is:

        $$\text{Var}(X) = E[X^2] - (E[X])^2 = \sum_x x^2 \cdot p_X(x) - \left(\sum_x x \cdot p_X(x)\right)^2$$
        """
    )
    return


@app.cell
def _(dist_pmf_values, dist_x_values, ev_dist_mean, np):
    def calc_variance(x_values, pmf_values, mean_value):
        """Calculate the variance of a discrete random variable."""
        return sum((x - mean_value)**2 * p for x, p in zip(x_values, pmf_values))

    # Calculate variance for the current distribution
    var_dist_var = calc_variance(dist_x_values, dist_pmf_values, ev_dist_mean)
    var_dist_std_dev = np.sqrt(var_dist_var)

    print(f"Variance: {var_dist_var:.4f}")
    print(f"Standard deviation: {var_dist_std_dev:.4f}")
    return calc_variance, var_dist_std_dev, var_dist_var


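@app.cell
def _(dist_pmf_values, dist_x_values, ev_dist_mean, var_dist_var):
    # Illustrative addition (not part of the original commit): cross-check
    # the variance using the alternative formula Var(X) = E[X^2] - (E[X])^2;
    # both forms should agree for the currently selected distribution.
    alt_second_moment = sum(x**2 * p for x, p in zip(dist_x_values, dist_pmf_values))
    alt_variance = alt_second_moment - ev_dist_mean**2

    print(f"Variance (definition):      {var_dist_var:.6f}")
    print(f"Variance (E[X^2] - E[X]^2): {alt_variance:.6f}")
    return alt_second_moment, alt_variance

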
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## PMF vs. CDF

        The **Cumulative Distribution Function (CDF)** is related to the PMF but gives the probability that the random variable $X$ is less than or equal to a value $x$:

        $$F_X(x) = P(X \leq x) = \sum_{k \leq x} p_X(k)$$

        While the PMF gives the probability mass at each point, the CDF accumulates these probabilities.
        """
    )
    return


@app.cell(hide_code=True)
def _(dist_pmf_values, dist_x_values, np, plt):
    # Calculate the CDF from the PMF
    cdf_dist_values = np.cumsum(dist_pmf_values)

    # Create a plot comparing PMF and CDF
    cdf_fig, (cdf_ax1, cdf_ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # PMF plot
    cdf_ax1.bar(dist_x_values, dist_pmf_values, width=0.4, alpha=0.7)
    cdf_ax1.set_xlabel('x')
    cdf_ax1.set_ylabel('P(X = x)')
    cdf_ax1.set_title('Probability Mass Function (PMF)')
    cdf_ax1.grid(alpha=0.3)

    # CDF plot - using step function with 'post' style for proper discrete representation
    cdf_ax2.step(dist_x_values, cdf_dist_values, where='post', linewidth=2, color='blue')
    cdf_ax2.scatter(dist_x_values, cdf_dist_values, s=50, color='blue')

    # Set appropriate limits for better visualization
    if len(dist_x_values) > 0:
        x_min = min(dist_x_values) - 0.5
        x_max = max(dist_x_values) + 0.5
        cdf_ax2.set_xlim(x_min, x_max)
        cdf_ax2.set_ylim(0, 1.05)  # CDF goes from 0 to 1

    cdf_ax2.set_xlabel('x')
    cdf_ax2.set_ylabel('P(X ≤ x)')
    cdf_ax2.set_title('Cumulative Distribution Function (CDF)')
    cdf_ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.gca()  # Return the current axes to ensure proper display
    return cdf_ax1, cdf_ax2, cdf_dist_values, cdf_fig, x_max, x_min


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        The graphs above illustrate the key difference between PMF and CDF:

        - **PMF (left)**: Shows the probability of the random variable taking each specific value: P(X = x)
        - **CDF (right)**: Shows the probability of the random variable being less than or equal to each value: P(X ≤ x)

        The CDF at any point is the sum of all PMF values up to and including that point. This is why the CDF is always non-decreasing and eventually reaches 1. For discrete distributions like this one, the CDF forms a step function that jumps at each value in the support of the random variable.
        """
    )
    return


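@app.cell
def _(cdf_dist_values, dist_pmf_values, np):
    # Illustrative addition (not part of the original commit): the PMF can be
    # recovered from the CDF by differencing consecutive values,
    # p_X(x_k) = F_X(x_k) - F_X(x_{k-1}), which inverts the cumulative sum.
    recovered_pmf = np.diff(cdf_dist_values, prepend=0.0)

    recovery_error = np.max(np.abs(recovered_pmf - dist_pmf_values))
    print(f"Max |recovered PMF - original PMF|: {recovery_error:.2e}")
    return recovered_pmf, recovery_error

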
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Test Your Understanding

        Choose what you believe are the correct options in the questions below:

        <details>
        <summary>If X is a discrete random variable with PMF p(x), then p(x) must always be less than 1</summary>
        ❌ False! While most values in a PMF are typically less than 1, a PMF can have p(x) = 1 for a specific value if the random variable always takes that value (with 100% probability).
        </details>

        <details>
        <summary>The sum of all probabilities in a PMF must equal exactly 1</summary>
        ✅ True! This is a fundamental property of any valid PMF. The total probability across all possible values must be 1, as the random variable must take some value.
        </details>

        <details>
        <summary>A PMF can be estimated from data by creating a normalized histogram</summary>
        ✅ True! Counting the frequency of each value and dividing by the total number of observations gives an empirical PMF.
        </details>

        <details>
        <summary>The expected value of a discrete random variable is always one of the possible values of the variable</summary>
        ❌ False! The expected value is a weighted average and may not be a value the random variable can actually take. For example, the expected value of a fair die roll is 3.5, which is not a possible outcome.
        </details>
        """
    )
    return


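@app.cell
def _(calc_expected_value):
    # Illustrative addition (not part of the original commit): verify the
    # last claim above by computing the expected value of a fair die with
    # the calc_expected_value helper defined earlier; 3.5 is not a face.
    die_expectation = calc_expected_value(range(1, 7), [1/6] * 6)
    print(f"E[die roll] = {die_expectation:.2f}")
    return (die_expectation,)

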
@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Practical Applications of PMFs

        PMFs pop up everywhere: network engineers use them to model traffic patterns, reliability teams predict equipment failures, and marketers analyze purchase behavior. In finance, they help price options; in gaming, they're behind every dice roll. Machine learning algorithms like Naive Bayes rely on them, and they're essential for modeling rare events like genetic mutations or system failures.
        """
    )
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Key Takeaways

        PMFs give us the probability picture for discrete random variables: they tell us how likely each value is, must be non-negative, and always sum to 1. We can write them as equations, draw them as graphs, or estimate them from data. They're the foundation for calculating expected values and variances, which we'll explore in our next notebook on Expectation, where we'll learn how to summarize random variables with a single, most "expected" value.
        """
    )
    return


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    import matplotlib.pyplot as plt
    import numpy as np
    from scipy import stats
    import collections
    return collections, np, plt, stats


if __name__ == "__main__":
    app.run()