File size: 24,433 Bytes
7e35597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "marimo",
#     "matplotlib==3.10.0",
#     "numpy==2.2.3",
#     "scipy==1.15.2",
# ]
# ///

import marimo

__generated_with = "0.11.17"
app = marimo.App(width="medium", app_title="Probability Mass Functions")


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        # Probability Mass Functions

        _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/pmf/), by Stanford professor Chris Piech._

        For a random variable, the most important thing to know is: how likely is each outcome? For a discrete random variable, this information is called the "**Probability Mass Function**". The probability mass function (PMF) provides the "mass" (i.e. amount) of "probability" for each possible assignment of the random variable.

        Formally, the Probability Mass Function is a mapping between the values that the random variable could take on and the probability of the random variable taking on said value. In mathematics, we call these associations functions. There are many different ways of representing functions: you can write an equation, you can make a graph, you can even store many samples in a list.
        """
    )
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Properties of a PMF

        For a function $p_X(x)$ to be a valid PMF, it must satisfy:

        1. **Non-negativity**: $p_X(x) \geq 0$ for all $x$
        2. **Unit total probability**: $\sum_x p_X(x) = 1$

        ### Probabilities Must Sum to 1

        For a variable (call it $X$) to be a proper random variable, it must be the case that if you summed up the values of $P(X=x)$ for all possible values $x$ that $X$ can take on, the result must be 1:

        $$\sum_x P(X=x) = 1$$

        This is because a random variable taking on a value is an event (for example $X=3$). Each of those events is mutually exclusive because a random variable will take on exactly one value. Those mutually exclusive cases define an entire sample space. Why? Because $X$ must take on some value.
        """
    )
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## PMFs as Graphs

        Let's start by looking at PMFs as graphs where the $x$-axis is the values that the random variable could take on and the $y$-axis is the probability of the random variable taking on said value.

        In the following example, we show two PMFs:

        - On the left: PMF for the random variable $X$ = the value of a single six-sided die roll
        - On the right: PMF for the random variable $Y$ = value of the sum of two dice rolls
        """
    )
    return


@app.cell(hide_code=True)
def _(np, plt):
    # Single die PMF
    single_die_values = np.arange(1, 7)
    single_die_probs = np.ones(6) / 6

    # Two dice sum PMF
    two_dice_values = np.arange(2, 13)
    two_dice_probs = []

    for dice_sum in two_dice_values:
        if dice_sum <= 7:
            dice_prob = (dice_sum-1) / 36
        else:
            dice_prob = (13-dice_sum) / 36
        two_dice_probs.append(dice_prob)

    # Create side-by-side plots
    dice_fig, (dice_ax1, dice_ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Single die plot
    dice_ax1.bar(single_die_values, single_die_probs, width=0.4)
    dice_ax1.set_xticks(single_die_values)
    dice_ax1.set_xlabel('Value of die roll (x)')
    dice_ax1.set_ylabel('Probability: P(X = x)')
    dice_ax1.set_title('PMF of a Single Die Roll')
    dice_ax1.grid(alpha=0.3)

    # Two dice sum plot
    dice_ax2.bar(two_dice_values, two_dice_probs, width=0.4)
    dice_ax2.set_xticks(two_dice_values)
    dice_ax2.set_xlabel('Sum of two dice (y)')
    dice_ax2.set_ylabel('Probability: P(Y = y)')
    dice_ax2.set_title('PMF of Sum of Two Dice')
    dice_ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.gca()
    return (
        dice_ax1,
        dice_ax2,
        dice_fig,
        dice_prob,
        dice_sum,
        single_die_probs,
        single_die_values,
        two_dice_probs,
        two_dice_values,
    )


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        The information provided in these graphs shows the likelihood of a random variable taking on different values.

        In the graph on the right, the value "6" on the $x$-axis is associated with the probability $\frac{5}{36}$ on the $y$-axis. This $x$-axis refers to the event "the sum of two dice is 6" or $Y = 6$. The $y$-axis tells us that the probability of that event is $\frac{5}{36}$. In full: $P(Y = 6) = \frac{5}{36}$.

        The value "2" is associated with "$\frac{1}{36}$" which tells us that, $P(Y = 2) = \frac{1}{36}$, the probability that two dice sum to 2 is $\frac{1}{36}$. There is no value associated with "1" because the sum of two dice cannot be 1.
        """
    )
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## PMFs as Equations

        Here is the exact same information in equation form:

        For a single die roll $X$:
        $$P(X=x) = \frac{1}{6} \quad \text{ if } 1 \leq x \leq 6$$

        For the sum of two dice $Y$:
        $$P(Y=y) = \begin{cases}
        \frac{(y-1)}{36} & \text{ if } 2 \leq y \leq 7\\
        \frac{(13-y)}{36} & \text{ if } 8 \leq y \leq 12
        \end{cases}$$

        Let's implement the PMF for $Y$, the sum of two dice, in Python code:
        """
    )
    return


@app.cell
def _():
    def pmf_sum_two_dice(y_val):
        """Returns the probability that the sum of two dice is y"""
        if y_val < 2 or y_val > 12:
            return 0
        if y_val <= 7:
            return (y_val-1) / 36
        else:
            return (13-y_val) / 36

    # Test the function for a few values
    test_values = [1, 2, 7, 12, 13]
    for test_y in test_values:
        print(f"P(Y = {test_y}) = {pmf_sum_two_dice(test_y)}")
    return pmf_sum_two_dice, test_values, test_y


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""Now, let's verify that our PMF satisfies the property that the sum of all probabilities equals 1:""")
    return


@app.cell
def _(pmf_sum_two_dice):
    # Verify that probabilities sum to 1
    verify_total_prob = sum(pmf_sum_two_dice(y_val) for y_val in range(2, 13))
    # Round to 10 decimal places to handle floating-point precision
    verify_total_prob_rounded = round(verify_total_prob, 10)
    print(f"Sum of all probabilities: {verify_total_prob_rounded}")
    return verify_total_prob, verify_total_prob_rounded


@app.cell(hide_code=True)
def _(plt, pmf_sum_two_dice):
    # Create a visual verification
    verify_y_values = list(range(2, 13))
    verify_probabilities = [pmf_sum_two_dice(y_val) for y_val in verify_y_values]

    plt.figure(figsize=(10, 4))
    plt.bar(verify_y_values, verify_probabilities, width=0.4)
    plt.xticks(verify_y_values)
    plt.xlabel('Sum of two dice (y)')
    plt.ylabel('Probability: P(Y = y)')
    plt.title('PMF of Sum of Two Dice (Total Probability = 1)')
    plt.grid(alpha=0.3)

    # Add probability values on top of bars
    for verify_i, verify_prob in enumerate(verify_probabilities):
        plt.text(verify_y_values[verify_i], verify_prob + 0.001, f'{verify_prob:.3f}', ha='center')

    plt.gca()  # Return the current axes to ensure proper display
    return verify_i, verify_prob, verify_probabilities, verify_y_values


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Data to Histograms to Probability Mass Functions

        One surprising way to store a likelihood function (recall that a PMF is the name of the likelihood function for discrete random variables) is simply a list of data. Let's simulate summing two dice many times to create an empirical PMF:
        """
    )
    return


@app.cell
def _(np):
    # Simulate rolling two dice many times
    sim_num_trials = 10000
    np.random.seed(42)  # For reproducibility

    # Generate random dice rolls
    sim_die1 = np.random.randint(1, 7, size=sim_num_trials)
    sim_die2 = np.random.randint(1, 7, size=sim_num_trials)

    # Calculate the sum
    sim_dice_sums = sim_die1 + sim_die2

    # Display a small sample of the data
    print(f"First 20 dice sums: {sim_dice_sums[:20]}")
    print(f"Total number of trials: {sim_num_trials}")
    return sim_dice_sums, sim_die1, sim_die2, sim_num_trials


@app.cell(hide_code=True)
def _(collections, np, plt, sim_dice_sums):
    # Count the frequency of each sum
    sim_counter = collections.Counter(sim_dice_sums)

    # Sort the values
    sim_sorted_values = sorted(sim_counter.keys())

    # Calculate the empirical PMF
    sim_empirical_pmf = [sim_counter[x] / len(sim_dice_sums) for x in sim_sorted_values]

    # Calculate the theoretical PMF
    sim_theoretical_values = np.arange(2, 13)
    sim_theoretical_pmf = []
    for sim_y in sim_theoretical_values:
        if sim_y <= 7:
            sim_prob = (sim_y-1) / 36
        else:
            sim_prob = (13-sim_y) / 36
        sim_theoretical_pmf.append(sim_prob)

    # Create a comparison plot
    sim_fig, (sim_ax1, sim_ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Empirical PMF (normalized histogram)
    sim_ax1.bar(sim_sorted_values, sim_empirical_pmf, width=0.4)
    sim_ax1.set_xticks(sim_sorted_values)
    sim_ax1.set_xlabel('Sum of two dice')
    sim_ax1.set_ylabel('Empirical Probability')
    sim_ax1.set_title(f'Empirical PMF from {len(sim_dice_sums)} Trials')
    sim_ax1.grid(alpha=0.3)

    # Theoretical PMF
    sim_ax2.bar(sim_theoretical_values, sim_theoretical_pmf, width=0.4)
    sim_ax2.set_xticks(sim_theoretical_values)
    sim_ax2.set_xlabel('Sum of two dice')
    sim_ax2.set_ylabel('Theoretical Probability')
    sim_ax2.set_title('Theoretical PMF')
    sim_ax2.grid(alpha=0.3)

    plt.tight_layout()

    # Let's also look at the raw counts (histogram)
    plt.figure(figsize=(10, 4))
    sim_counts = [sim_counter[x] for x in sim_sorted_values]
    plt.bar(sim_sorted_values, sim_counts, width=0.4)
    plt.xticks(sim_sorted_values)
    plt.xlabel('Sum of two dice')
    plt.ylabel('Frequency')
    plt.title('Histogram of Dice Sum Frequencies')
    plt.grid(alpha=0.3)

    # Add count values on top of bars
    for sim_i, sim_count in enumerate(sim_counts):
        plt.text(sim_sorted_values[sim_i], sim_count + 19, str(sim_count), ha='center')

    plt.gca()  # Return the current axes to ensure proper display
    return (
        sim_ax1,
        sim_ax2,
        sim_count,
        sim_counter,
        sim_counts,
        sim_empirical_pmf,
        sim_fig,
        sim_i,
        sim_prob,
        sim_sorted_values,
        sim_theoretical_pmf,
        sim_theoretical_values,
        sim_y,
    )


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        A normalized histogram (where each value is divided by the length of your data list) is an approximation of the PMF. For a dataset of discrete numbers, a histogram shows the count of each value. By the definition of probability, if you divide this count by the number of experiments run, you arrive at an approximation of the probability of the event $P(Y=y)$.

        Let's look at a specific example. If we want to approximate $P(Y=3)$ (the probability that the sum of two dice is 3), we can count the number of times that "3" occurs in our data and divide by the total number of trials:
        """
    )
    return


@app.cell
def _(sim_counter, sim_dice_sums):
    # Calculate P(Y=3) empirically
    sim_count_of_3 = sim_counter[3]
    sim_empirical_prob = sim_count_of_3 / len(sim_dice_sums)

    # Calculate P(Y=3) theoretically
    sim_theoretical_prob = 2/36  # There are 2 ways to get a sum of 3 out of 36 possible outcomes

    print(f"Count of sum=3: {sim_count_of_3}")
    print(f"Empirical P(Y=3): {sim_count_of_3}/{len(sim_dice_sums)} = {sim_empirical_prob:.4f}")
    print(f"Theoretical P(Y=3): 2/36 = {sim_theoretical_prob:.4f}")
    print(f"Difference: {abs(sim_empirical_prob - sim_theoretical_prob):.4f}")
    return sim_count_of_3, sim_empirical_prob, sim_theoretical_prob


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        As we can see, with a large number of trials, the empirical PMF becomes a very good approximation of the theoretical PMF. This is an example of the [Law of Large Numbers](https://en.wikipedia.org/wiki/Law_of_large_numbers) in action.

        ## Interactive Example: Exploring PMFs

        Let's create an interactive tool to explore different PMFs:
        """
    )
    return


@app.cell
def _(dist_param1, dist_param2, dist_selection, mo):
    mo.hstack([dist_selection, dist_param1, dist_param2], justify="space-around")
    return


@app.cell(hide_code=True)
def _(mo):
    dist_selection = mo.ui.dropdown(
        options=[
            "bernoulli",
            "binomial",
            "geometric",
            "poisson"
        ],
        value="bernoulli",
        label="Select a distribution"
    )

    # Parameters for different distributions
    dist_param1 = mo.ui.slider(
        start=0.05, 
        stop=0.95, 
        step=0.05, 
        value=0.5, 
        label="p (success probability)"
    )

    dist_param2 = mo.ui.slider(
        start=1, 
        stop=20, 
        step=1, 
        value=10, 
        label="n (trials) or λ (rate)"
    )
    return dist_param1, dist_param2, dist_selection


@app.cell(hide_code=True)
def _(dist_param1, dist_param2, dist_selection, np, plt, stats):
    # Set up the plot based on the selected distribution
    if dist_selection.value == "bernoulli":
        # Bernoulli distribution
        dist_p = dist_param1.value
        dist_x_values = np.array([0, 1])
        dist_pmf_values = [1-dist_p, dist_p]
        dist_title = f"Bernoulli PMF (p = {dist_p:.2f})"
        dist_x_label = "Outcome (0 = Failure, 1 = Success)"
        dist_max_x = 1

    elif dist_selection.value == "binomial":
        # Binomial distribution
        dist_n = int(dist_param2.value)
        dist_p = dist_param1.value
        dist_x_values = np.arange(0, dist_n+1)
        dist_pmf_values = stats.binom.pmf(dist_x_values, dist_n, dist_p)
        dist_title = f"Binomial PMF (n = {dist_n}, p = {dist_p:.2f})"
        dist_x_label = "Number of Successes"
        dist_max_x = dist_n

    elif dist_selection.value == "geometric":
        # Geometric distribution
        dist_p = dist_param1.value
        dist_max_x = min(int(5/dist_p), 50)  # Limit the range for visualization
        dist_x_values = np.arange(1, dist_max_x+1)
        dist_pmf_values = stats.geom.pmf(dist_x_values, dist_p)
        dist_title = f"Geometric PMF (p = {dist_p:.2f})"
        dist_x_label = "Number of Trials Until First Success"

    else:  # Poisson
        # Poisson distribution
        dist_lam = dist_param2.value
        dist_max_x = int(dist_lam*3) + 1  # Reasonable range for visualization
        dist_x_values = np.arange(0, dist_max_x)
        dist_pmf_values = stats.poisson.pmf(dist_x_values, dist_lam)
        dist_title = f"Poisson PMF (λ = {dist_lam})"
        dist_x_label = "Number of Events"

    # Create the plot
    plt.figure(figsize=(10, 5))

    # For discrete distributions, use stem plot for clarity
    dist_markerline, dist_stemlines, dist_baseline = plt.stem(
        dist_x_values, dist_pmf_values, markerfmt='o', basefmt=' '
    )
    plt.setp(dist_markerline, markersize=6)
    plt.setp(dist_stemlines, linewidth=1.5)

    # Add a bar plot for better visibility
    plt.bar(dist_x_values, dist_pmf_values, alpha=0.3, width=0.4)

    plt.xlabel(dist_x_label)
    plt.ylabel("Probability: P(X = x)")
    plt.title(dist_title)
    plt.grid(alpha=0.3)

    # Calculate and display expected value and variance
    if dist_selection.value == "bernoulli":
        dist_mean = dist_p
        dist_variance = dist_p * (1-dist_p)
    elif dist_selection.value == "binomial":
        dist_mean = dist_n * dist_p
        dist_variance = dist_n * dist_p * (1-dist_p)
    elif dist_selection.value == "geometric":
        dist_mean = 1/dist_p
        dist_variance = (1-dist_p)/(dist_p**2)
    else:  # Poisson
        dist_mean = dist_lam
        dist_variance = dist_lam

    dist_std_dev = np.sqrt(dist_variance)

    # Add text with distribution properties
    dist_props_text = (
        f"Mean: {dist_mean:.3f}\n"
        f"Variance: {dist_variance:.3f}\n"
        f"Std Dev: {dist_std_dev:.3f}\n"
        f"Sum of probabilities: {sum(dist_pmf_values):.6f}"
    )

    plt.text(0.95, 0.95, dist_props_text,
             transform=plt.gca().transAxes,
             verticalalignment='top',
             horizontalalignment='right',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.gca()  # Return the current axes to ensure proper display
    return (
        dist_baseline,
        dist_lam,
        dist_markerline,
        dist_max_x,
        dist_mean,
        dist_n,
        dist_p,
        dist_pmf_values,
        dist_props_text,
        dist_std_dev,
        dist_stemlines,
        dist_title,
        dist_variance,
        dist_x_label,
        dist_x_values,
    )


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Expected Value from a PMF

        The expected value (or mean) of a discrete random variable is calculated using its PMF:

        $$E[X] = \sum_x x \cdot p_X(x)$$

        This represents the long-run average value of the random variable.
        """
    )
    return


@app.cell
def _(dist_pmf_values, dist_x_values):
    def calc_expected_value(x_values, pmf_values):
        """Calculate the expected value of a discrete random variable."""
        return sum(x * p for x, p in zip(x_values, pmf_values))

    # Calculate expected value for the current distribution
    ev_dist_mean = calc_expected_value(dist_x_values, dist_pmf_values)

    print(f"Expected value: {ev_dist_mean:.4f}")
    return calc_expected_value, ev_dist_mean


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Variance from a PMF

        The variance measures the spread or dispersion of a random variable around its mean:

        $$\text{Var}(X) = E[(X - E[X])^2] = \sum_x (x - E[X])^2 \cdot p_X(x)$$

        An alternative formula is:

        $$\text{Var}(X) = E[X^2] - (E[X])^2 = \sum_x x^2 \cdot p_X(x) - \left(\sum_x x \cdot p_X(x)\right)^2$$
        """
    )
    return


@app.cell
def _(dist_pmf_values, dist_x_values, ev_dist_mean, np):
    def calc_variance(x_values, pmf_values, mean_value):
        """Calculate the variance of a discrete random variable."""
        return sum((x - mean_value)**2 * p for x, p in zip(x_values, pmf_values))

    # Calculate variance for the current distribution
    var_dist_var = calc_variance(dist_x_values, dist_pmf_values, ev_dist_mean)
    var_dist_std_dev = np.sqrt(var_dist_var)

    print(f"Variance: {var_dist_var:.4f}")
    print(f"Standard deviation: {var_dist_std_dev:.4f}")
    return calc_variance, var_dist_std_dev, var_dist_var


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## PMF vs. CDF

        The **Cumulative Distribution Function (CDF)** is related to the PMF but gives the probability that the random variable $X$ is less than or equal to a value $x$:

        $$F_X(x) = P(X \leq x) = \sum_{k \leq x} p_X(k)$$

        While the PMF gives the probability mass at each point, the CDF accumulates these probabilities.
        """
    )
    return


@app.cell(hide_code=True)
def _(dist_pmf_values, dist_x_values, np, plt):
    # Calculate the CDF from the PMF
    cdf_dist_values = np.cumsum(dist_pmf_values)

    # Create a plot comparing PMF and CDF
    cdf_fig, (cdf_ax1, cdf_ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # PMF plot
    cdf_ax1.bar(dist_x_values, dist_pmf_values, width=0.4, alpha=0.7)
    cdf_ax1.set_xlabel('x')
    cdf_ax1.set_ylabel('P(X = x)')
    cdf_ax1.set_title('Probability Mass Function (PMF)')
    cdf_ax1.grid(alpha=0.3)

    # CDF plot - using step function with 'post' style for proper discrete representation
    cdf_ax2.step(dist_x_values, cdf_dist_values, where='post', linewidth=2, color='blue')
    cdf_ax2.scatter(dist_x_values, cdf_dist_values, s=50, color='blue')

    # Set appropriate limits for better visualization
    if len(dist_x_values) > 0:
        x_min = min(dist_x_values) - 0.5
        x_max = max(dist_x_values) + 0.5
        cdf_ax2.set_xlim(x_min, x_max)
        cdf_ax2.set_ylim(0, 1.05)  # CDF goes from 0 to 1

    cdf_ax2.set_xlabel('x')
    cdf_ax2.set_ylabel('P(X ≤ x)')
    cdf_ax2.set_title('Cumulative Distribution Function (CDF)')
    cdf_ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.gca()  # Return the current axes to ensure proper display
    return cdf_ax1, cdf_ax2, cdf_dist_values, cdf_fig, x_max, x_min


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        The graphs above illustrate the key difference between PMF and CDF:

        - **PMF (left)**: Shows the probability of the random variable taking each specific value: P(X = x)
        - **CDF (right)**: Shows the probability of the random variable being less than or equal to each value: P(X ≤ x)

        The CDF at any point is the sum of all PMF values up to and including that point. This is why the CDF is always non-decreasing and eventually reaches 1. For discrete distributions like this one, the CDF forms a step function that jumps at each value in the support of the random variable.
        """
    )
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Test Your Understanding

        Choose what you believe are the correct options in the questions below:

        <details>
        <summary>If X is a discrete random variable with PMF p(x), then p(x) must always be less than 1</summary>
        ❌ False! While most values in a PMF are typically less than 1, a PMF can have p(x) = 1 for a specific value if the random variable always takes that value (with 100% probability).
        </details>

        <details>
        <summary>The sum of all probabilities in a PMF must equal exactly 1</summary>
        ✅ True! This is a fundamental property of any valid PMF. The total probability across all possible values must be 1, as the random variable must take some value.
        </details>

        <details>
        <summary>A PMF can be estimated from data by creating a normalized histogram</summary>
        ✅ True! Counting the frequency of each value and dividing by the total number of observations gives an empirical PMF.
        </details>

        <details>
        <summary>The expected value of a discrete random variable is always one of the possible values of the variable</summary>
        ❌ False! The expected value is a weighted average and may not be a value the random variable can actually take. For example, the expected value of a fair die roll is 3.5, which is not a possible outcome.
        </details>
        """
    )
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Practical Applications of PMFs

        PMFs pop up everywhere - network engineers use them to model traffic patterns, reliability teams predict equipment failures, and marketers analyze purchase behavior. In finance, they help price options; in gaming, they're behind every dice roll. Machine learning algorithms like Naive Bayes rely on them, and they're essential for modeling rare events like genetic mutations or system failures.
        """
    )
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(
        r"""
        ## Key Takeaways

        PMFs give us the probability picture for discrete random variables - they tell us how likely each value is, must be non-negative, and always sum to 1. We can write them as equations, draw them as graphs, or estimate them from data. They're the foundation for calculating expected values and variances, which we'll explore in our next notebook on Expectation, where we'll learn how to summarize random variables with a single, most "expected" value.
        """
    )
    return


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    import matplotlib.pyplot as plt
    import numpy as np
    from scipy import stats
    import collections
    return collections, np, plt, stats


if __name__ == "__main__":
    app.run()