Spaces:
Running
Running
Add `Maximum Likelihood Estimation` notebook
Browse files
probability/19_maximum_likelihood_estimation.py
ADDED
@@ -0,0 +1,1236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.1",
|
6 |
+
# "scipy==1.15.2",
|
7 |
+
# "numpy==2.2.4",
|
8 |
+
# "polars==0.20.2",
|
9 |
+
# "plotly==5.18.0",
|
10 |
+
# ]
|
11 |
+
# ///
|
12 |
+
|
13 |
+
import marimo
|
14 |
+
|
15 |
+
__generated_with = "0.12.0"
|
16 |
+
app = marimo.App(width="medium", app_title="Maximum Likelihood Estimation")
|
17 |
+
|
18 |
+
|
19 |
+
@app.cell(hide_code=True)
|
20 |
+
def _(mo):
|
21 |
+
mo.md(
|
22 |
+
r"""
|
23 |
+
# Maximum Likelihood Estimation
|
24 |
+
|
25 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part5/mle/), by Stanford professor Chris Piech._
|
26 |
+
|
27 |
+
Maximum Likelihood Estimation (MLE) is a fundamental method in statistics for estimating parameters of a probability distribution. The central idea is elegantly simple: **choose the parameters that make the observed data most likely**.
|
28 |
+
|
29 |
+
In this notebook, we'll try to understand MLE, starting with the core concept of likelihood and how it differs from probability. We'll explore how to formulate MLE problems mathematically and then solve them for various common distributions. Along the way, I've included some interactive visualizations to help build your intuition about these concepts. You'll see how MLE applies to real-world scenarios like linear regression, and hopefully gain a deeper appreciation for why this technique is so widely used in statistics and machine learning. Think of MLE as detective work - we have some evidence (our data) and we're trying to figure out the most plausible explanation (our parameters) for what we've observed.
|
30 |
+
"""
|
31 |
+
)
|
32 |
+
return
|
33 |
+
|
34 |
+
|
35 |
+
@app.cell(hide_code=True)
|
36 |
+
def _(mo):
|
37 |
+
mo.md(
|
38 |
+
r"""
|
39 |
+
## Likelihood: The Core Concept
|
40 |
+
|
41 |
+
Before diving into MLE, we need to understand what "likelihood" means in a statistical context.
|
42 |
+
|
43 |
+
### Data and Parameters
|
44 |
+
|
45 |
+
Suppose we have collected some data $X_1, X_2, \ldots, X_n$ that are independent and identically distributed (IID). We assume these data points come from a specific type of distribution (like Normal, Bernoulli, etc.) with unknown parameters $\theta$.
|
46 |
+
|
47 |
+
### What is Likelihood?
|
48 |
+
|
49 |
+
Likelihood measures how probable our observed data is, given specific values of the parameters $\theta$.
|
50 |
+
|
51 |
+
- For **discrete** distributions: likelihood is the probability mass function (PMF) of our data
|
52 |
+
- For **continuous** distributions: likelihood is the probability density function (PDF) of our data
|
53 |
+
|
54 |
+
/// note
|
55 |
+
**Probability vs. Likelihood**
|
56 |
+
|
57 |
+
- **Probability**: Given parameters $\theta$, what's the chance of observing data $X$?
|
58 |
+
- **Likelihood**: Given observed data $X$, how likely are different parameter values $\theta$?
|
59 |
+
|
60 |
+
They use the same formula but different perspectives!
|
61 |
+
///
|
62 |
+
|
63 |
+
To simplify notation, we'll use $f(X=x|\Theta=\theta)$ to represent either the PMF or PDF of our data, conditioned on the parameters.
|
64 |
+
"""
|
65 |
+
)
|
66 |
+
return
|
67 |
+
|
68 |
+
|
69 |
+
@app.cell(hide_code=True)
|
70 |
+
def _(mo):
|
71 |
+
mo.md(
|
72 |
+
r"""
|
73 |
+
### The Likelihood Function
|
74 |
+
|
75 |
+
Since we assume our data points are independent, the likelihood of all our data is the product of the likelihoods of each individual data point:
|
76 |
+
|
77 |
+
$$L(\theta) = \prod_{i=1}^n f(X_i = x_i|\Theta = \theta)$$
|
78 |
+
|
79 |
+
This function $L(\theta)$ gives us the likelihood of observing our entire dataset for different parameter values $\theta$.
|
80 |
+
|
81 |
+
/// tip
|
82 |
+
**Key Insight**: Different parameter values produce different likelihoods for the same data. Better parameter values will make the observed data more likely.
|
83 |
+
///
|
84 |
+
"""
|
85 |
+
)
|
86 |
+
return
|
87 |
+
|
88 |
+
|
89 |
+
@app.cell(hide_code=True)
|
90 |
+
def _(mo):
|
91 |
+
mo.md(
|
92 |
+
r"""
|
93 |
+
## Maximum Likelihood Estimation
|
94 |
+
|
95 |
+
The core idea of MLE is to find the parameter values $\hat{\theta}$ that maximize the likelihood function:
|
96 |
+
|
97 |
+
$$\hat{\theta} = \underset{\theta}{\operatorname{argmax}} \, L(\theta)$$
|
98 |
+
|
99 |
+
The notation $\hat{\theta}$ represents our best estimate of the true parameters based on the observed data.
|
100 |
+
|
101 |
+
### Working with Log-Likelihood
|
102 |
+
|
103 |
+
In practice, we usually work with the **log-likelihood** instead of the likelihood directly. Since logarithm is a monotonically increasing function, the maximum of $L(\theta)$ occurs at the same value of $\theta$ as the maximum of $\log L(\theta)$.
|
104 |
+
|
105 |
+
Taking the logarithm transforms our product into a sum, which is much easier to work with:
|
106 |
+
|
107 |
+
$$LL(\theta) = \log L(\theta) = \log \prod_{i=1}^n f(X_i=x_i|\Theta = \theta) = \sum_{i=1}^n \log f(X_i = x_i|\Theta = \theta)$$
|
108 |
+
|
109 |
+
/// warning
|
110 |
+
Working with products of many small probabilities can lead to numerical underflow. Taking the logarithm converts these products to sums, which is numerically more stable.
|
111 |
+
///
|
112 |
+
"""
|
113 |
+
)
|
114 |
+
return
|
115 |
+
|
116 |
+
|
117 |
+
@app.cell(hide_code=True)
|
118 |
+
def _(mo):
|
119 |
+
mo.md(
|
120 |
+
r"""
|
121 |
+
### Finding the Maximum
|
122 |
+
|
123 |
+
To find the values of $\theta$ that maximize the log-likelihood, we typically:
|
124 |
+
|
125 |
+
1. Take the derivative of $LL(\theta)$ with respect to each parameter
|
126 |
+
2. Set each derivative equal to zero
|
127 |
+
3. Solve for the parameters
|
128 |
+
|
129 |
+
Let's see this approach in action with some common distributions.
|
130 |
+
"""
|
131 |
+
)
|
132 |
+
return
|
133 |
+
|
134 |
+
|
135 |
+
@app.cell(hide_code=True)
|
136 |
+
def _(mo):
|
137 |
+
mo.md(
|
138 |
+
r"""
|
139 |
+
## MLE for Bernoulli Distribution
|
140 |
+
|
141 |
+
Let's start with a simple example: estimating the parameter $p$ of a Bernoulli distribution.
|
142 |
+
|
143 |
+
### The Model
|
144 |
+
|
145 |
+
A Bernoulli distribution has a single parameter $p$ which represents the probability of success (getting a value of 1). Its probability mass function (PMF) can be written as:
|
146 |
+
|
147 |
+
$$f(x|p) = p^x(1-p)^{1-x}, \quad x \in \{0, 1\}$$
|
148 |
+
|
149 |
+
This elegant formula works because:
|
150 |
+
|
151 |
+
- When $x = 1$: $f(1|p) = p^1(1-p)^0 = p$
|
152 |
+
- When $x = 0$: $f(0|p) = p^0(1-p)^1 = 1-p$
|
153 |
+
|
154 |
+
### Deriving the MLE
|
155 |
+
|
156 |
+
Given $n$ independent Bernoulli trials $X_1, X_2, \ldots, X_n$, we want to find the value of $p$ that maximizes the likelihood of our observed data.
|
157 |
+
|
158 |
+
Step 1: Write the likelihood function
|
159 |
+
$$L(p) = \prod_{i=1}^n p^{x_i}(1-p)^{1-x_i}$$
|
160 |
+
|
161 |
+
Step 2: Take the logarithm to get the log-likelihood
|
162 |
+
$$\begin{align*}
|
163 |
+
LL(p) &= \sum_{i=1}^n \log(p^{x_i}(1-p)^{1-x_i}) \\
|
164 |
+
&= \sum_{i=1}^n \left[x_i \log(p) + (1-x_i)\log(1-p)\right] \\
|
165 |
+
&= \left(\sum_{i=1}^n x_i\right) \log(p) + \left(n - \sum_{i=1}^n x_i\right) \log(1-p) \\
|
166 |
+
&= Y\log(p) + (n-Y)\log(1-p)
|
167 |
+
\end{align*}$$
|
168 |
+
|
169 |
+
where $Y = \sum_{i=1}^n x_i$ is the total number of successes.
|
170 |
+
|
171 |
+
Step 3: Find the value of $p$ that maximizes $LL(p)$ by setting the derivative to zero
|
172 |
+
$$\begin{align*}
|
173 |
+
\frac{d\,LL(p)}{dp} &= \frac{Y}{p} - \frac{n-Y}{1-p} = 0 \\
|
174 |
+
\frac{Y}{p} &= \frac{n-Y}{1-p} \\
|
175 |
+
Y(1-p) &= p(n-Y) \\
|
176 |
+
Y - Yp &= pn - pY \\
|
177 |
+
Y &= pn \\
|
178 |
+
\hat{p} &= \frac{Y}{n} = \frac{\sum_{i=1}^n x_i}{n}
|
179 |
+
\end{align*}$$
|
180 |
+
|
181 |
+
/// tip
|
182 |
+
The MLE for the parameter $p$ in a Bernoulli distribution is simply the **sample mean** - the proportion of successes in our data!
|
183 |
+
///
|
184 |
+
"""
|
185 |
+
)
|
186 |
+
return
|
187 |
+
|
188 |
+
|
189 |
+
@app.cell(hide_code=True)
|
190 |
+
def _(controls):
|
191 |
+
controls.center()
|
192 |
+
return
|
193 |
+
|
194 |
+
|
195 |
+
@app.cell(hide_code=True)
|
196 |
+
def _(generate_button, mo, np, plt, sample_size_slider, true_p_slider):
|
197 |
+
# generate bernoulli samples when button is clicked
|
198 |
+
bernoulli_button_value = generate_button.value
|
199 |
+
|
200 |
+
# get parameter values
|
201 |
+
bernoulli_true_p = true_p_slider.value
|
202 |
+
bernoulli_n = sample_size_slider.value
|
203 |
+
|
204 |
+
# generate data
|
205 |
+
bernoulli_data = np.random.binomial(1, bernoulli_true_p, size=bernoulli_n)
|
206 |
+
bernoulli_Y = np.sum(bernoulli_data)
|
207 |
+
bernoulli_p_hat = bernoulli_Y / bernoulli_n
|
208 |
+
|
209 |
+
# create visualization
|
210 |
+
bernoulli_fig, (bernoulli_ax1, bernoulli_ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
211 |
+
|
212 |
+
# plot data histogram
|
213 |
+
bernoulli_ax1.hist(bernoulli_data, bins=[-0.5, 0.5, 1.5], rwidth=0.8, color='lightblue')
|
214 |
+
bernoulli_ax1.set_xticks([0, 1])
|
215 |
+
bernoulli_ax1.set_xticklabels(['Failure (0)', 'Success (1)'])
|
216 |
+
bernoulli_ax1.set_title(f'Bernoulli Data: {bernoulli_n} samples')
|
217 |
+
bernoulli_ax1.set_ylabel('Count')
|
218 |
+
bernoulli_y_counts = [bernoulli_n - bernoulli_Y, bernoulli_Y]
|
219 |
+
for bernoulli_idx, bernoulli_count in enumerate(bernoulli_y_counts):
|
220 |
+
bernoulli_ax1.text(bernoulli_idx, bernoulli_count/2, f"{bernoulli_count}",
|
221 |
+
ha='center', va='center',
|
222 |
+
color='white' if bernoulli_idx == 0 else 'black',
|
223 |
+
fontweight='bold')
|
224 |
+
|
225 |
+
# calculate log-likelihood function
|
226 |
+
bernoulli_p_values = np.linspace(0.01, 0.99, 100)
|
227 |
+
bernoulli_ll_values = np.zeros_like(bernoulli_p_values)
|
228 |
+
|
229 |
+
for bernoulli_i, bernoulli_p in enumerate(bernoulli_p_values):
|
230 |
+
bernoulli_ll_values[bernoulli_i] = bernoulli_Y * np.log(bernoulli_p) + (bernoulli_n - bernoulli_Y) * np.log(1 - bernoulli_p)
|
231 |
+
|
232 |
+
# plot log-likelihood
|
233 |
+
bernoulli_ax2.plot(bernoulli_p_values, bernoulli_ll_values, 'b-', linewidth=2)
|
234 |
+
bernoulli_ax2.axvline(x=bernoulli_p_hat, color='r', linestyle='--', label=f'MLE: $\\hat{{p}} = {bernoulli_p_hat:.3f}$')
|
235 |
+
bernoulli_ax2.axvline(x=bernoulli_true_p, color='g', linestyle='--', label=f'True: $p = {bernoulli_true_p:.3f}$')
|
236 |
+
bernoulli_ax2.set_xlabel('$p$ (probability of success)')
|
237 |
+
bernoulli_ax2.set_ylabel('Log-Likelihood')
|
238 |
+
bernoulli_ax2.set_title('Log-Likelihood Function')
|
239 |
+
bernoulli_ax2.legend()
|
240 |
+
|
241 |
+
plt.tight_layout()
|
242 |
+
plt.gca()
|
243 |
+
|
244 |
+
# Create markdown to explain the results
|
245 |
+
bernoulli_explanation = mo.md(
|
246 |
+
f"""
|
247 |
+
### Bernoulli MLE Results
|
248 |
+
|
249 |
+
**True parameter**: $p = {bernoulli_true_p:.3f}$
|
250 |
+
**Sample statistics**: {bernoulli_Y} successes out of {bernoulli_n} trials
|
251 |
+
**MLE estimate**: $\\hat{{p}} = \\frac{{{bernoulli_Y}}}{{{bernoulli_n}}} = {bernoulli_p_hat:.3f}$
|
252 |
+
|
253 |
+
The plot on the right shows the log-likelihood function $LL(p) = Y\\log(p) + (n-Y)\\log(1-p)$.
|
254 |
+
The red dashed line marks the maximum likelihood estimate $\\hat{{p}}$, and the green dashed line
|
255 |
+
shows the true parameter value.
|
256 |
+
|
257 |
+
/// note
|
258 |
+
Try increasing the sample size to see how the MLE estimate gets closer to the true parameter value!
|
259 |
+
///
|
260 |
+
"""
|
261 |
+
)
|
262 |
+
|
263 |
+
# Display plot and explanation together
|
264 |
+
mo.vstack([
|
265 |
+
bernoulli_fig,
|
266 |
+
bernoulli_explanation
|
267 |
+
])
|
268 |
+
return (
|
269 |
+
bernoulli_Y,
|
270 |
+
bernoulli_ax1,
|
271 |
+
bernoulli_ax2,
|
272 |
+
bernoulli_button_value,
|
273 |
+
bernoulli_count,
|
274 |
+
bernoulli_data,
|
275 |
+
bernoulli_explanation,
|
276 |
+
bernoulli_fig,
|
277 |
+
bernoulli_i,
|
278 |
+
bernoulli_idx,
|
279 |
+
bernoulli_ll_values,
|
280 |
+
bernoulli_n,
|
281 |
+
bernoulli_p,
|
282 |
+
bernoulli_p_hat,
|
283 |
+
bernoulli_p_values,
|
284 |
+
bernoulli_true_p,
|
285 |
+
bernoulli_y_counts,
|
286 |
+
)
|
287 |
+
|
288 |
+
|
289 |
+
@app.cell(hide_code=True)
|
290 |
+
def _(mo):
|
291 |
+
mo.md(
|
292 |
+
r"""
|
293 |
+
## MLE for Normal Distribution
|
294 |
+
|
295 |
+
Next, let's look at a more complex example: estimating the parameters $\mu$ and $\sigma^2$ of a Normal distribution.
|
296 |
+
|
297 |
+
### The Model
|
298 |
+
|
299 |
+
A Normal (Gaussian) distribution has two parameters:
|
300 |
+
- $\mu$: the mean
|
301 |
+
- $\sigma^2$: the variance
|
302 |
+
|
303 |
+
Its probability density function (PDF) is:
|
304 |
+
|
305 |
+
$$f(x|\mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)$$
|
306 |
+
|
307 |
+
### Deriving the MLE
|
308 |
+
|
309 |
+
Given $n$ independent samples $X_1, X_2, \ldots, X_n$ from a Normal distribution, we want to find the values of $\mu$ and $\sigma^2$ that maximize the likelihood of our observed data.
|
310 |
+
|
311 |
+
Step 1: Write the likelihood function
|
312 |
+
$$L(\mu, \sigma^2) = \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)$$
|
313 |
+
|
314 |
+
Step 2: Take the logarithm to get the log-likelihood
|
315 |
+
$$\begin{align*}
|
316 |
+
LL(\mu, \sigma^2) &= \log\prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right) \\
|
317 |
+
&= \sum_{i=1}^n \log\left[\frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right)\right] \\
|
318 |
+
&= \sum_{i=1}^n \left[-\frac{1}{2}\log(2\pi\sigma^2) - \frac{(x_i - \mu)^2}{2\sigma^2}\right] \\
|
319 |
+
&= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\sum_{i=1}^n (x_i - \mu)^2
|
320 |
+
\end{align*}$$
|
321 |
+
|
322 |
+
Step 3: Find the values of $\mu$ and $\sigma^2$ that maximize $LL(\mu, \sigma^2)$ by setting the partial derivatives to zero.
|
323 |
+
|
324 |
+
For $\mu$:
|
325 |
+
$$\begin{align*}
|
326 |
+
\frac{\partial LL(\mu, \sigma^2)}{\partial \mu} &= \frac{1}{\sigma^2}\sum_{i=1}^n (x_i - \mu) = 0 \\
|
327 |
+
\sum_{i=1}^n (x_i - \mu) &= 0 \\
|
328 |
+
\sum_{i=1}^n x_i &= n\mu \\
|
329 |
+
\hat{\mu} &= \frac{1}{n}\sum_{i=1}^n x_i
|
330 |
+
\end{align*}$$
|
331 |
+
|
332 |
+
For $\sigma^2$:
|
333 |
+
$$\begin{align*}
|
334 |
+
\frac{\partial LL(\mu, \sigma^2)}{\partial \sigma^2} &= -\frac{n}{2\sigma^2} + \frac{1}{2(\sigma^2)^2}\sum_{i=1}^n (x_i - \mu)^2 = 0 \\
|
335 |
+
\frac{n}{2\sigma^2} &= \frac{1}{2(\sigma^2)^2}\sum_{i=1}^n (x_i - \mu)^2 \\
|
336 |
+
n\sigma^2 &= \sum_{i=1}^n (x_i - \mu)^2 \\
|
337 |
+
\hat{\sigma}^2 &= \frac{1}{n}\sum_{i=1}^n (x_i - \hat{\mu})^2
|
338 |
+
\end{align*}$$
|
339 |
+
|
340 |
+
/// tip
|
341 |
+
The MLE for a Normal distribution gives us:
|
342 |
+
|
343 |
+
- $\hat{\mu}$ = sample mean
|
344 |
+
- $\hat{\sigma}^2$ = sample variance (using $n$ in the denominator, not $n-1$)
|
345 |
+
///
|
346 |
+
"""
|
347 |
+
)
|
348 |
+
return
|
349 |
+
|
350 |
+
|
351 |
+
@app.cell(hide_code=True)
|
352 |
+
def _(normal_controls):
|
353 |
+
normal_controls.center()
|
354 |
+
return
|
355 |
+
|
356 |
+
|
357 |
+
@app.cell(hide_code=True)
|
358 |
+
def _(
|
359 |
+
mo,
|
360 |
+
normal_generate_button,
|
361 |
+
normal_sample_size_slider,
|
362 |
+
np,
|
363 |
+
plt,
|
364 |
+
true_mu_slider,
|
365 |
+
true_sigma_slider,
|
366 |
+
):
|
367 |
+
# generate normal samples when button is clicked
|
368 |
+
normal_button_value = normal_generate_button.value
|
369 |
+
|
370 |
+
# get parameter values
|
371 |
+
normal_true_mu = true_mu_slider.value
|
372 |
+
normal_true_sigma = true_sigma_slider.value
|
373 |
+
normal_true_var = normal_true_sigma**2
|
374 |
+
normal_n = normal_sample_size_slider.value
|
375 |
+
|
376 |
+
# generate random data
|
377 |
+
normal_data = np.random.normal(normal_true_mu, normal_true_sigma, size=normal_n)
|
378 |
+
|
379 |
+
# calculate mle estimates
|
380 |
+
normal_mu_hat = np.mean(normal_data)
|
381 |
+
normal_sigma2_hat = np.mean((normal_data - normal_mu_hat)**2) # mle variance using n
|
382 |
+
normal_sigma_hat = np.sqrt(normal_sigma2_hat)
|
383 |
+
|
384 |
+
# create visualization
|
385 |
+
normal_fig, (normal_ax1, normal_ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
386 |
+
|
387 |
+
# plot histogram and density curves
|
388 |
+
normal_bins = np.linspace(min(normal_data) - 1, max(normal_data) + 1, 30)
|
389 |
+
normal_ax1.hist(normal_data, bins=normal_bins, density=True, alpha=0.6, color='lightblue', label='Data Histogram')
|
390 |
+
|
391 |
+
# plot range for density curves
|
392 |
+
normal_x = np.linspace(min(normal_data) - 2*normal_true_sigma, max(normal_data) + 2*normal_true_sigma, 1000)
|
393 |
+
|
394 |
+
# plot true and mle densities
|
395 |
+
normal_true_pdf = (1/(normal_true_sigma * np.sqrt(2*np.pi))) * np.exp(-0.5 * ((normal_x - normal_true_mu)/normal_true_sigma)**2)
|
396 |
+
normal_ax1.plot(normal_x, normal_true_pdf, 'g-', linewidth=2, label=f'True: N({normal_true_mu:.2f}, {normal_true_var:.2f})')
|
397 |
+
|
398 |
+
normal_mle_pdf = (1/(normal_sigma_hat * np.sqrt(2*np.pi))) * np.exp(-0.5 * ((normal_x - normal_mu_hat)/normal_sigma_hat)**2)
|
399 |
+
normal_ax1.plot(normal_x, normal_mle_pdf, 'r--', linewidth=2, label=f'MLE: N({normal_mu_hat:.2f}, {normal_sigma2_hat:.2f})')
|
400 |
+
|
401 |
+
normal_ax1.set_xlabel('x')
|
402 |
+
normal_ax1.set_ylabel('Density')
|
403 |
+
normal_ax1.set_title(f'Normal Distribution: {normal_n} samples')
|
404 |
+
normal_ax1.legend()
|
405 |
+
|
406 |
+
# create contour plot of log-likelihood
|
407 |
+
normal_mu_range = np.linspace(normal_mu_hat - 2, normal_mu_hat + 2, 100)
|
408 |
+
normal_sigma_range = np.linspace(max(0.1, normal_sigma_hat - 1), normal_sigma_hat + 1, 100)
|
409 |
+
|
410 |
+
normal_mu_grid, normal_sigma_grid = np.meshgrid(normal_mu_range, normal_sigma_range)
|
411 |
+
normal_ll_grid = np.zeros_like(normal_mu_grid)
|
412 |
+
|
413 |
+
# calculate log-likelihood for each grid point
|
414 |
+
for normal_i in range(normal_mu_grid.shape[0]):
|
415 |
+
for normal_j in range(normal_mu_grid.shape[1]):
|
416 |
+
normal_mu = normal_mu_grid[normal_i, normal_j]
|
417 |
+
normal_sigma = normal_sigma_grid[normal_i, normal_j]
|
418 |
+
normal_ll = -normal_n/2 * np.log(2*np.pi*normal_sigma**2) - np.sum((normal_data - normal_mu)**2)/(2*normal_sigma**2)
|
419 |
+
normal_ll_grid[normal_i, normal_j] = normal_ll
|
420 |
+
|
421 |
+
# plot log-likelihood contour
|
422 |
+
normal_contour = normal_ax2.contourf(normal_mu_grid, normal_sigma_grid, normal_ll_grid, levels=50, cmap='viridis')
|
423 |
+
normal_ax2.set_xlabel('μ (mean)')
|
424 |
+
normal_ax2.set_ylabel('σ (standard deviation)')
|
425 |
+
normal_ax2.set_title('Log-Likelihood Contour')
|
426 |
+
|
427 |
+
# mark mle and true params
|
428 |
+
normal_ax2.plot(normal_mu_hat, normal_sigma_hat, 'rx', markersize=10, label='MLE Estimate')
|
429 |
+
normal_ax2.plot(normal_true_mu, normal_true_sigma, 'g*', markersize=10, label='True Parameters')
|
430 |
+
normal_ax2.legend()
|
431 |
+
|
432 |
+
plt.colorbar(normal_contour, ax=normal_ax2, label='Log-Likelihood')
|
433 |
+
plt.tight_layout()
|
434 |
+
plt.gca()
|
435 |
+
|
436 |
+
# relevant markdown for the results
|
437 |
+
normal_explanation = mo.md(
|
438 |
+
f"""
|
439 |
+
### Normal MLE Results
|
440 |
+
|
441 |
+
**True parameters**: $\mu = {normal_true_mu:.3f}$, $\sigma^2 = {normal_true_var:.3f}$
|
442 |
+
**MLE estimates**: $\hat{{\mu}} = {normal_mu_hat:.3f}$, $\hat{{\sigma}}^2 = {normal_sigma2_hat:.3f}$
|
443 |
+
|
444 |
+
The left plot shows the data histogram with the true Normal distribution (green) and the MLE-estimated distribution (red dashed).
|
445 |
+
|
446 |
+
The right plot shows the log-likelihood function as a contour map in the $(\mu, \sigma)$ parameter space. The maximum likelihood estimates are marked with a red X, while the true parameters are marked with a green star.
|
447 |
+
|
448 |
+
/// note
|
449 |
+
Notice how the log-likelihood contour is more stretched along the σ axis than the μ axis. This indicates that we typically estimate the mean with greater precision than the standard deviation.
|
450 |
+
///
|
451 |
+
|
452 |
+
/// tip
|
453 |
+
Increase the sample size to see how the MLE estimates converge to the true parameter values!
|
454 |
+
///
|
455 |
+
"""
|
456 |
+
)
|
457 |
+
|
458 |
+
# plot and explanation together
|
459 |
+
mo.vstack([
|
460 |
+
normal_fig,
|
461 |
+
normal_explanation
|
462 |
+
])
|
463 |
+
return (
|
464 |
+
normal_ax1,
|
465 |
+
normal_ax2,
|
466 |
+
normal_bins,
|
467 |
+
normal_button_value,
|
468 |
+
normal_contour,
|
469 |
+
normal_data,
|
470 |
+
normal_explanation,
|
471 |
+
normal_fig,
|
472 |
+
normal_i,
|
473 |
+
normal_j,
|
474 |
+
normal_ll,
|
475 |
+
normal_ll_grid,
|
476 |
+
normal_mle_pdf,
|
477 |
+
normal_mu,
|
478 |
+
normal_mu_grid,
|
479 |
+
normal_mu_hat,
|
480 |
+
normal_mu_range,
|
481 |
+
normal_n,
|
482 |
+
normal_sigma,
|
483 |
+
normal_sigma2_hat,
|
484 |
+
normal_sigma_grid,
|
485 |
+
normal_sigma_hat,
|
486 |
+
normal_sigma_range,
|
487 |
+
normal_true_mu,
|
488 |
+
normal_true_pdf,
|
489 |
+
normal_true_sigma,
|
490 |
+
normal_true_var,
|
491 |
+
normal_x,
|
492 |
+
)
|
493 |
+
|
494 |
+
|
495 |
+
@app.cell(hide_code=True)
|
496 |
+
def _(mo):
|
497 |
+
mo.md(
|
498 |
+
r"""
|
499 |
+
## MLE for Linear Regression
|
500 |
+
|
501 |
+
Now let's look at a more practical example: using MLE to derive linear regression.
|
502 |
+
|
503 |
+
### The Model
|
504 |
+
|
505 |
+
Consider a model where:
|
506 |
+
- We have pairs of observations $(X_1, Y_1), (X_2, Y_2), \ldots, (X_n, Y_n)$
|
507 |
+
- The relationship between $X$ and $Y$ follows: $Y = \theta X + Z$
|
508 |
+
- $Z \sim N(0, \sigma^2)$ is random noise
|
509 |
+
- Our goal is to estimate the parameter $\theta$
|
510 |
+
|
511 |
+
This means that for a given $X_i$, the conditional distribution of $Y_i$ is:
|
512 |
+
|
513 |
+
$$Y_i | X_i \sim N(\theta X_i, \sigma^2)$$
|
514 |
+
|
515 |
+
### Deriving the MLE
|
516 |
+
|
517 |
+
Step 1: Write the likelihood function for each data point $(X_i, Y_i)$
|
518 |
+
$$f(Y_i | X_i, \theta) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right)$$
|
519 |
+
|
520 |
+
Step 2: Write the likelihood for all data
|
521 |
+
$$\begin{align*}
|
522 |
+
L(\theta) &= \prod_{i=1}^n f(Y_i, X_i | \theta) \\
|
523 |
+
&= \prod_{i=1}^n f(Y_i | X_i, \theta) \cdot f(X_i)
|
524 |
+
\end{align*}$$
|
525 |
+
|
526 |
+
Since $f(X_i)$ doesn't depend on $\theta$, we can simplify:
|
527 |
+
$$L(\theta) = \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right) \cdot f(X_i)$$
|
528 |
+
|
529 |
+
Step 3: Take the logarithm to get the log-likelihood
|
530 |
+
$$\begin{align*}
|
531 |
+
LL(\theta) &= \log \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right) \cdot f(X_i) \\
|
532 |
+
&= \sum_{i=1}^n \log\left[\frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(Y_i - \theta X_i)^2}{2\sigma^2}\right)\right] + \sum_{i=1}^n \log f(X_i) \\
|
533 |
+
&= -\frac{n}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (Y_i - \theta X_i)^2 + \sum_{i=1}^n \log f(X_i)
|
534 |
+
\end{align*}$$
|
535 |
+
|
536 |
+
Step 4: Since we only care about maximizing with respect to $\theta$, we can drop terms that don't contain $\theta$:
|
537 |
+
$$\hat{\theta} = \underset{\theta}{\operatorname{argmax}} \left[ -\frac{1}{2\sigma^2} \sum_{i=1}^n (Y_i - \theta X_i)^2 \right]$$
|
538 |
+
|
539 |
+
This is equivalent to:
|
540 |
+
$$\hat{\theta} = \underset{\theta}{\operatorname{argmin}} \sum_{i=1}^n (Y_i - \theta X_i)^2$$
|
541 |
+
|
542 |
+
Step 5: Find the value of $\theta$ that minimizes the sum of squared errors by setting the derivative to zero:
|
543 |
+
$$\begin{align*}
|
544 |
+
\frac{d}{d\theta} \sum_{i=1}^n (Y_i - \theta X_i)^2 &= 0 \\
|
545 |
+
\sum_{i=1}^n -2X_i(Y_i - \theta X_i) &= 0 \\
|
546 |
+
\sum_{i=1}^n X_i Y_i - \theta X_i^2 &= 0 \\
|
547 |
+
\sum_{i=1}^n X_i Y_i &= \theta \sum_{i=1}^n X_i^2 \\
|
548 |
+
\hat{\theta} &= \frac{\sum_{i=1}^n X_i Y_i}{\sum_{i=1}^n X_i^2}
|
549 |
+
\end{align*}$$
|
550 |
+
|
551 |
+
/// tip
|
552 |
+
**Key Insight**: MLE for this simple linear model gives us the least squares estimator! This is an important connection between MLE and regression.
|
553 |
+
///
|
554 |
+
"""
|
555 |
+
)
|
556 |
+
return
|
557 |
+
|
558 |
+
|
559 |
+
@app.cell(hide_code=True)
|
560 |
+
def _(linear_controls):
|
561 |
+
linear_controls.center()
|
562 |
+
return
|
563 |
+
|
564 |
+
|
565 |
+
@app.cell(hide_code=True)
|
566 |
+
def _(
|
567 |
+
linear_generate_button,
|
568 |
+
linear_sample_size_slider,
|
569 |
+
mo,
|
570 |
+
noise_sigma_slider,
|
571 |
+
np,
|
572 |
+
plt,
|
573 |
+
true_theta_slider,
|
574 |
+
):
|
575 |
+
# linear model data calc when button is clicked
|
576 |
+
linear_button_value = linear_generate_button.value
|
577 |
+
|
578 |
+
# get parameter values
|
579 |
+
linear_true_theta = true_theta_slider.value
|
580 |
+
linear_noise_sigma = noise_sigma_slider.value
|
581 |
+
linear_n = linear_sample_size_slider.value
|
582 |
+
|
583 |
+
# generate x data (uniformly between -3 and 3)
|
584 |
+
linear_X = np.random.uniform(-3, 3, size=linear_n)
|
585 |
+
|
586 |
+
# generate y data according to the model y = θx + z
|
587 |
+
linear_Z = np.random.normal(0, linear_noise_sigma, size=linear_n)
|
588 |
+
linear_Y = linear_true_theta * linear_X + linear_Z
|
589 |
+
|
590 |
+
# calculate mle estimate
|
591 |
+
linear_theta_hat = np.sum(linear_X * linear_Y) / np.sum(linear_X**2)
|
592 |
+
|
593 |
+
# calculate sse for different theta values
|
594 |
+
linear_theta_range = np.linspace(linear_true_theta - 1.5, linear_true_theta + 1.5, 100)
|
595 |
+
linear_sse_values = np.zeros_like(linear_theta_range)
|
596 |
+
|
597 |
+
for linear_i, linear_theta in enumerate(linear_theta_range):
|
598 |
+
linear_y_pred = linear_theta * linear_X
|
599 |
+
linear_sse_values[linear_i] = np.sum((linear_Y - linear_y_pred)**2)
|
600 |
+
|
601 |
+
# convert sse to log-likelihood (ignoring constant terms)
|
602 |
+
linear_ll_values = -linear_sse_values / (2 * linear_noise_sigma**2)
|
603 |
+
|
604 |
+
# create visualization
|
605 |
+
linear_fig, (linear_ax1, linear_ax2) = plt.subplots(1, 2, figsize=(12, 5))
|
606 |
+
|
607 |
+
# plot scatter plot with regression lines
|
608 |
+
linear_ax1.scatter(linear_X, linear_Y, color='blue', alpha=0.6, label='Data points')
|
609 |
+
|
610 |
+
# plot range for regression lines
|
611 |
+
linear_x_line = np.linspace(-3, 3, 100)
|
612 |
+
|
613 |
+
# plot true and mle regression lines
|
614 |
+
linear_ax1.plot(linear_x_line, linear_true_theta * linear_x_line, 'g-', linewidth=2, label=f'True: Y = {linear_true_theta:.2f}X')
|
615 |
+
linear_ax1.plot(linear_x_line, linear_theta_hat * linear_x_line, 'r--', linewidth=2, label=f'MLE: Y = {linear_theta_hat:.2f}X')
|
616 |
+
|
617 |
+
linear_ax1.set_xlabel('X')
|
618 |
+
linear_ax1.set_ylabel('Y')
|
619 |
+
linear_ax1.set_title(f'Linear Regression: {linear_n} data points')
|
620 |
+
linear_ax1.grid(True, alpha=0.3)
|
621 |
+
linear_ax1.legend()
|
622 |
+
|
623 |
+
# plot log-likelihood function
|
624 |
+
linear_ax2.plot(linear_theta_range, linear_ll_values, 'b-', linewidth=2)
|
625 |
+
linear_ax2.axvline(x=linear_theta_hat, color='r', linestyle='--', label=f'MLE: θ = {linear_theta_hat:.3f}')
|
626 |
+
linear_ax2.axvline(x=linear_true_theta, color='g', linestyle='--', label=f'True: θ = {linear_true_theta:.3f}')
|
627 |
+
linear_ax2.set_xlabel('θ (slope parameter)')
|
628 |
+
linear_ax2.set_ylabel('Log-Likelihood')
|
629 |
+
linear_ax2.set_title('Log-Likelihood Function')
|
630 |
+
linear_ax2.grid(True, alpha=0.3)
|
631 |
+
linear_ax2.legend()
|
632 |
+
|
633 |
+
plt.tight_layout()
|
634 |
+
plt.gca()
|
635 |
+
|
636 |
+
# relevant markdown to explain results
|
637 |
+
linear_explanation = mo.md(
|
638 |
+
f"""
|
639 |
+
### Linear Regression MLE Results
|
640 |
+
|
641 |
+
**True parameter**: $\\theta = {linear_true_theta:.3f}$
|
642 |
+
**MLE estimate**: $\\hat{{\\theta}} = {linear_theta_hat:.3f}$
|
643 |
+
|
644 |
+
The left plot shows the scatter plot of data points with the true regression line (green) and the MLE-estimated regression line (red dashed).
|
645 |
+
|
646 |
+
The right plot shows the log-likelihood function for different values of $\\theta$. The maximum likelihood estimate is marked with a red dashed line, and the true parameter is marked with a green dashed line.
|
647 |
+
|
648 |
+
/// note
|
649 |
+
The MLE estimate $\\hat{{\\theta}} = \\frac{{\\sum_{{i=1}}^n X_i Y_i}}{{\\sum_{{i=1}}^n X_i^2}}$ minimizes the sum of squared errors between the predicted and actual Y values.
|
650 |
+
///
|
651 |
+
|
652 |
+
/// tip
|
653 |
+
Try increasing the noise level to see how it affects the precision of the estimate!
|
654 |
+
///
|
655 |
+
"""
|
656 |
+
)
|
657 |
+
|
658 |
+
# show plot and explanation
|
659 |
+
mo.vstack([
|
660 |
+
linear_fig,
|
661 |
+
linear_explanation
|
662 |
+
])
|
663 |
+
return (
|
664 |
+
linear_X,
|
665 |
+
linear_Y,
|
666 |
+
linear_Z,
|
667 |
+
linear_ax1,
|
668 |
+
linear_ax2,
|
669 |
+
linear_button_value,
|
670 |
+
linear_explanation,
|
671 |
+
linear_fig,
|
672 |
+
linear_i,
|
673 |
+
linear_ll_values,
|
674 |
+
linear_n,
|
675 |
+
linear_noise_sigma,
|
676 |
+
linear_sse_values,
|
677 |
+
linear_theta,
|
678 |
+
linear_theta_hat,
|
679 |
+
linear_theta_range,
|
680 |
+
linear_true_theta,
|
681 |
+
linear_x_line,
|
682 |
+
linear_y_pred,
|
683 |
+
)
|
684 |
+
|
685 |
+
|
686 |
+
@app.cell(hide_code=True)
|
687 |
+
def _(mo):
|
688 |
+
mo.md(
|
689 |
+
r"""
|
690 |
+
## Interactive Concept: Likelihood vs. Probability
|
691 |
+
|
692 |
+
To better understand the distinction between likelihood and probability, let's create an interactive visualization. This concept is crucial for understanding why MLE works.
|
693 |
+
"""
|
694 |
+
)
|
695 |
+
return
|
696 |
+
|
697 |
+
|
698 |
+
@app.cell(hide_code=True)
|
699 |
+
def _(concept_controls):
|
700 |
+
concept_controls.center()
|
701 |
+
return
|
702 |
+
|
703 |
+
|
704 |
+
@app.cell(hide_code=True)
|
705 |
+
def _(concept_dist_type, mo, np, perspective_selector, plt, stats):
|
706 |
+
# current distribution type
|
707 |
+
concept_dist_type_value = concept_dist_type.value
|
708 |
+
|
709 |
+
# view mode from dropdown
|
710 |
+
concept_view_mode = "likelihood" if perspective_selector.value == "Likelihood Perspective" else "probability"
|
711 |
+
|
712 |
+
# visualization based on distribution type
|
713 |
+
concept_fig, concept_ax = plt.subplots(figsize=(10, 6))
|
714 |
+
|
715 |
+
if concept_dist_type_value == "Normal":
|
716 |
+
if concept_view_mode == "probability":
|
717 |
+
# probability perspective: fixed parameters, varying data
|
718 |
+
concept_mu = 0 # fixed parameter
|
719 |
+
concept_sigma = 1 # fixed parameter
|
720 |
+
|
721 |
+
# generate x values for the pdf
|
722 |
+
concept_x = np.linspace(-4, 4, 1000)
|
723 |
+
|
724 |
+
# plot pdf
|
725 |
+
concept_pdf = stats.norm.pdf(concept_x, concept_mu, concept_sigma)
|
726 |
+
concept_ax.plot(concept_x, concept_pdf, 'b-', linewidth=2, label='PDF: N(0, 1)')
|
727 |
+
|
728 |
+
# highlight specific data values
|
729 |
+
concept_data_points = [-2, -1, 0, 1, 2]
|
730 |
+
concept_colors = ['#FF9999', '#FFCC99', '#99FF99', '#99CCFF', '#CC99FF']
|
731 |
+
|
732 |
+
for concept_i, concept_data in enumerate(concept_data_points):
|
733 |
+
concept_prob = stats.norm.pdf(concept_data, concept_mu, concept_sigma)
|
734 |
+
concept_ax.plot([concept_data, concept_data], [0, concept_prob], concept_colors[concept_i], linewidth=2)
|
735 |
+
concept_ax.scatter(concept_data, concept_prob, color=concept_colors[concept_i], s=50,
|
736 |
+
label=f'P(X={concept_data}|μ=0,σ=1) = {concept_prob:.3f}')
|
737 |
+
|
738 |
+
concept_ax.set_xlabel('Data (x)')
|
739 |
+
concept_ax.set_ylabel('Probability Density')
|
740 |
+
concept_ax.set_title('Probability Perspective: Fixed Parameters (μ=0, σ=1), Different Data Points')
|
741 |
+
|
742 |
+
else: # likelihood perspective
|
743 |
+
# likelihood perspective: fixed data, varying parameters
|
744 |
+
concept_data_point = 1.5 # fixed observed data
|
745 |
+
|
746 |
+
# different possible parameter values (means)
|
747 |
+
concept_mus = [-1, 0, 1, 2, 3]
|
748 |
+
concept_sigma = 1
|
749 |
+
|
750 |
+
# generate x values for multiple pdfs
|
751 |
+
concept_x = np.linspace(-4, 6, 1000)
|
752 |
+
|
753 |
+
concept_colors = ['#FF9999', '#FFCC99', '#99FF99', '#99CCFF', '#CC99FF']
|
754 |
+
|
755 |
+
for concept_i, concept_mu in enumerate(concept_mus):
|
756 |
+
concept_pdf = stats.norm.pdf(concept_x, concept_mu, concept_sigma)
|
757 |
+
concept_ax.plot(concept_x, concept_pdf, color=concept_colors[concept_i], linewidth=2, alpha=0.7,
|
758 |
+
label=f'N({concept_mu}, 1)')
|
759 |
+
|
760 |
+
# mark the likelihood of the data point for this param
|
761 |
+
concept_likelihood = stats.norm.pdf(concept_data_point, concept_mu, concept_sigma)
|
762 |
+
concept_ax.plot([concept_data_point, concept_data_point], [0, concept_likelihood], concept_colors[concept_i], linewidth=2)
|
763 |
+
concept_ax.scatter(concept_data_point, concept_likelihood, color=concept_colors[concept_i], s=50,
|
764 |
+
label=f'L(μ={concept_mu}|X=1.5) = {concept_likelihood:.3f}')
|
765 |
+
|
766 |
+
# add vertical line at the observed data point
|
767 |
+
concept_ax.axvline(x=concept_data_point, color='black', linestyle='--',
|
768 |
+
label=f'Observed data: X=1.5')
|
769 |
+
|
770 |
+
concept_ax.set_xlabel('Data (x)')
|
771 |
+
concept_ax.set_ylabel('Probability Density / Likelihood')
|
772 |
+
concept_ax.set_title('Likelihood Perspective: Fixed Data Point (X=1.5), Different Parameter Values')
|
773 |
+
|
774 |
+
elif concept_dist_type_value == "Bernoulli":
|
775 |
+
if concept_view_mode == "probability":
|
776 |
+
# probability perspective: fixed parameter, two possible data values
|
777 |
+
concept_p = 0.3 # fixed parameter
|
778 |
+
|
779 |
+
# bar chart for p(x=0) and p(x=1)
|
780 |
+
concept_ax.bar([0, 1], [1-concept_p, concept_p], width=0.4, color=['#99CCFF', '#FF9999'],
|
781 |
+
alpha=0.7, label=f'PMF: Bernoulli({concept_p})')
|
782 |
+
|
783 |
+
# text showing probabilities
|
784 |
+
concept_ax.text(0, (1-concept_p)/2, f'P(X=0|p={concept_p}) = {1-concept_p:.3f}', ha='center', va='center', fontweight='bold')
|
785 |
+
concept_ax.text(1, concept_p/2, f'P(X=1|p={concept_p}) = {concept_p:.3f}', ha='center', va='center', fontweight='bold')
|
786 |
+
|
787 |
+
concept_ax.set_xlabel('Data (x)')
|
788 |
+
concept_ax.set_ylabel('Probability')
|
789 |
+
concept_ax.set_xticks([0, 1])
|
790 |
+
concept_ax.set_xticklabels(['X=0', 'X=1'])
|
791 |
+
concept_ax.set_ylim(0, 1)
|
792 |
+
concept_ax.set_title('Probability Perspective: Fixed Parameter (p=0.3), Different Data Values')
|
793 |
+
|
794 |
+
else: # likelihood perspective
|
795 |
+
# likelihood perspective: fixed data, varying parameter
|
796 |
+
concept_data_point = 1 # fixed observed data (success)
|
797 |
+
|
798 |
+
# different possible parameter values
|
799 |
+
concept_p_values = np.linspace(0.01, 0.99, 100)
|
800 |
+
|
801 |
+
# calculate likelihood for each p value
|
802 |
+
if concept_data_point == 1:
|
803 |
+
# for x=1, likelihood is p
|
804 |
+
concept_likelihood = concept_p_values
|
805 |
+
concept_ax.plot(concept_p_values, concept_likelihood, 'b-', linewidth=2,
|
806 |
+
label=f'L(p|X=1) = p')
|
807 |
+
|
808 |
+
# highlight specific values
|
809 |
+
concept_highlight_ps = [0.2, 0.5, 0.8]
|
810 |
+
concept_colors = ['#FF9999', '#99FF99', '#99CCFF']
|
811 |
+
|
812 |
+
for concept_i, concept_p in enumerate(concept_highlight_ps):
|
813 |
+
concept_ax.plot([concept_p, concept_p], [0, concept_p], concept_colors[concept_i], linewidth=2)
|
814 |
+
concept_ax.scatter(concept_p, concept_p, color=concept_colors[concept_i], s=50,
|
815 |
+
label=f'L(p={concept_p}|X=1) = {concept_p:.3f}')
|
816 |
+
|
817 |
+
concept_ax.set_title('Likelihood Perspective: Fixed Data Point (X=1), Different Parameter Values')
|
818 |
+
|
819 |
+
else: # x=0
|
820 |
+
# for x = 0, likelihood is (1-p)
|
821 |
+
concept_likelihood = 1 - concept_p_values
|
822 |
+
concept_ax.plot(concept_p_values, concept_likelihood, 'r-', linewidth=2,
|
823 |
+
label=f'L(p|X=0) = (1-p)')
|
824 |
+
|
825 |
+
# highlight some specific values
|
826 |
+
concept_highlight_ps = [0.2, 0.5, 0.8]
|
827 |
+
concept_colors = ['#FF9999', '#99FF99', '#99CCFF']
|
828 |
+
|
829 |
+
for concept_i, concept_p in enumerate(concept_highlight_ps):
|
830 |
+
concept_ax.plot([concept_p, concept_p], [0, 1-concept_p], concept_colors[concept_i], linewidth=2)
|
831 |
+
concept_ax.scatter(concept_p, 1-concept_p, color=concept_colors[concept_i], s=50,
|
832 |
+
label=f'L(p={concept_p}|X=0) = {1-concept_p:.3f}')
|
833 |
+
|
834 |
+
concept_ax.set_title('Likelihood Perspective: Fixed Data Point (X=0), Different Parameter Values')
|
835 |
+
|
836 |
+
concept_ax.set_xlabel('Parameter (p)')
|
837 |
+
concept_ax.set_ylabel('Likelihood')
|
838 |
+
concept_ax.set_xlim(0, 1)
|
839 |
+
concept_ax.set_ylim(0, 1)
|
840 |
+
|
841 |
+
elif concept_dist_type_value == "Poisson":
|
842 |
+
if concept_view_mode == "probability":
|
843 |
+
# probability perspective: fixed parameter, different data values
|
844 |
+
concept_lam = 2.5 # fixed parameter
|
845 |
+
|
846 |
+
# pmf for different x values plot
|
847 |
+
concept_x_values = np.arange(0, 10)
|
848 |
+
concept_pmf_values = stats.poisson.pmf(concept_x_values, concept_lam)
|
849 |
+
|
850 |
+
concept_ax.bar(concept_x_values, concept_pmf_values, width=0.4, color='#99CCFF',
|
851 |
+
alpha=0.7, label=f'PMF: Poisson({concept_lam})')
|
852 |
+
|
853 |
+
# highlight a few specific values
|
854 |
+
concept_highlight_xs = [1, 2, 3, 4]
|
855 |
+
concept_colors = ['#FF9999', '#99FF99', '#FFCC99', '#CC99FF']
|
856 |
+
|
857 |
+
for concept_i, concept_x in enumerate(concept_highlight_xs):
|
858 |
+
concept_prob = stats.poisson.pmf(concept_x, concept_lam)
|
859 |
+
concept_ax.scatter(concept_x, concept_prob, color=concept_colors[concept_i], s=50,
|
860 |
+
label=f'P(X={concept_x}|λ={concept_lam}) = {concept_prob:.3f}')
|
861 |
+
|
862 |
+
concept_ax.set_xlabel('Data (x)')
|
863 |
+
concept_ax.set_ylabel('Probability')
|
864 |
+
concept_ax.set_xticks(concept_x_values)
|
865 |
+
concept_ax.set_title('Probability Perspective: Fixed Parameter (λ=2.5), Different Data Values')
|
866 |
+
|
867 |
+
else: # likelihood perspective
|
868 |
+
# likelihood perspective: fixed data, varying parameter
|
869 |
+
concept_data_point = 4 # fixed observed data
|
870 |
+
|
871 |
+
# different possible param values
|
872 |
+
concept_lambda_values = np.linspace(0.1, 8, 100)
|
873 |
+
|
874 |
+
# calc likelihood for each lambda value
|
875 |
+
concept_likelihood = stats.poisson.pmf(concept_data_point, concept_lambda_values)
|
876 |
+
|
877 |
+
concept_ax.plot(concept_lambda_values, concept_likelihood, 'b-', linewidth=2,
|
878 |
+
label=f'L(λ|X={concept_data_point})')
|
879 |
+
|
880 |
+
# highlight some specific values
|
881 |
+
concept_highlight_lambdas = [1, 2, 4, 6]
|
882 |
+
concept_colors = ['#FF9999', '#99FF99', '#99CCFF', '#FFCC99']
|
883 |
+
|
884 |
+
for concept_i, concept_lam in enumerate(concept_highlight_lambdas):
|
885 |
+
concept_like_val = stats.poisson.pmf(concept_data_point, concept_lam)
|
886 |
+
concept_ax.plot([concept_lam, concept_lam], [0, concept_like_val], concept_colors[concept_i], linewidth=2)
|
887 |
+
concept_ax.scatter(concept_lam, concept_like_val, color=concept_colors[concept_i], s=50,
|
888 |
+
label=f'L(λ={concept_lam}|X={concept_data_point}) = {concept_like_val:.3f}')
|
889 |
+
|
890 |
+
concept_ax.set_xlabel('Parameter (λ)')
|
891 |
+
concept_ax.set_ylabel('Likelihood')
|
892 |
+
concept_ax.set_title(f'Likelihood Perspective: Fixed Data Point (X={concept_data_point}), Different Parameter Values')
|
893 |
+
|
894 |
+
concept_ax.legend(loc='best', fontsize=9)
|
895 |
+
concept_ax.grid(True, alpha=0.3)
|
896 |
+
plt.tight_layout()
|
897 |
+
plt.gca()
|
898 |
+
|
899 |
+
# relevant explanation based on view mode
|
900 |
+
if concept_view_mode == "probability":
|
901 |
+
concept_explanation = mo.md(
|
902 |
+
f"""
|
903 |
+
### Probability Perspective
|
904 |
+
|
905 |
+
In the **probability perspective**, the parameters of the distribution are **fixed and known**, and we calculate the probability (or density) for **different possible data values**.
|
906 |
+
|
907 |
+
For the {concept_dist_type_value} distribution, we've fixed the parameter{'s' if concept_dist_type_value == 'Normal' else ''} and shown the probability of observing different outcomes.
|
908 |
+
|
909 |
+
This is the typical perspective when:
|
910 |
+
|
911 |
+
- We know the true parameters of a distribution
|
912 |
+
- We want to calculate the probability of different outcomes
|
913 |
+
- We make predictions based on our model
|
914 |
+
|
915 |
+
**Mathematical notation**: $P(X = x | \theta)$
|
916 |
+
"""
|
917 |
+
)
|
918 |
+
else: # likelihood perspective
|
919 |
+
concept_explanation = mo.md(
|
920 |
+
f"""
|
921 |
+
### Likelihood Perspective
|
922 |
+
|
923 |
+
In the **likelihood perspective**, the observed data is **fixed and known**, and we calculate how likely different parameter values are to have generated that data.
|
924 |
+
|
925 |
+
For the {concept_dist_type_value} distribution, we've fixed the observed data point{'s' if concept_dist_type_value == 'Normal' else ''} and shown the likelihood of different parameter values.
|
926 |
+
|
927 |
+
This is the perspective used in MLE:
|
928 |
+
|
929 |
+
- We have observed data
|
930 |
+
- We don't know the true parameters
|
931 |
+
- We want to find parameters that best explain our observations
|
932 |
+
|
933 |
+
**Mathematical notation**: $L(\theta | X = x)$
|
934 |
+
|
935 |
+
/// tip
|
936 |
+
The value of $\\theta$ that maximizes this likelihood function is the MLE estimate $\\hat{{\\theta}}$!
|
937 |
+
///
|
938 |
+
"""
|
939 |
+
)
|
940 |
+
|
941 |
+
# Display plot and explanation together
|
942 |
+
mo.vstack([
|
943 |
+
concept_fig,
|
944 |
+
concept_explanation
|
945 |
+
])
|
946 |
+
return (
|
947 |
+
concept_ax,
|
948 |
+
concept_colors,
|
949 |
+
concept_data,
|
950 |
+
concept_data_point,
|
951 |
+
concept_data_points,
|
952 |
+
concept_dist_type_value,
|
953 |
+
concept_explanation,
|
954 |
+
concept_fig,
|
955 |
+
concept_highlight_lambdas,
|
956 |
+
concept_highlight_ps,
|
957 |
+
concept_highlight_xs,
|
958 |
+
concept_i,
|
959 |
+
concept_lam,
|
960 |
+
concept_lambda_values,
|
961 |
+
concept_like_val,
|
962 |
+
concept_likelihood,
|
963 |
+
concept_mu,
|
964 |
+
concept_mus,
|
965 |
+
concept_p,
|
966 |
+
concept_p_values,
|
967 |
+
concept_pdf,
|
968 |
+
concept_pmf_values,
|
969 |
+
concept_prob,
|
970 |
+
concept_sigma,
|
971 |
+
concept_view_mode,
|
972 |
+
concept_x,
|
973 |
+
concept_x_values,
|
974 |
+
)
|
975 |
+
|
976 |
+
|
977 |
+
@app.cell(hide_code=True)
|
978 |
+
def _(mo):
|
979 |
+
mo.md(
|
980 |
+
r"""
|
981 |
+
## 🤔 Test Your Understanding
|
982 |
+
|
983 |
+
Which of the following statements about Maximum Likelihood Estimation are correct? Click each statement to check your answer.
|
984 |
+
|
985 |
+
/// details | Probability and likelihood use the same formulas, but probability measures the chance of data given parameters, while likelihood measures how likely parameters are given data.
|
986 |
+
✅ **Correct!**
|
987 |
+
|
988 |
+
Probability measures how likely it is to observe particular data when we know the parameters. Likelihood measures how likely particular parameter values are, given observed data.
|
989 |
+
|
990 |
+
Mathematically, probability is $P(X=x|\theta)$ while likelihood is $L(\theta|X=x)$. They use the same formula, but with different perspectives on what's fixed and what varies.
|
991 |
+
///
|
992 |
+
|
993 |
+
/// details | We use log-likelihood instead of likelihood because it's mathematically simpler and numerically more stable.
|
994 |
+
✅ **Correct!**
|
995 |
+
|
996 |
+
We work with log-likelihood for several reasons:
|
997 |
+
1. It converts products into sums, which is easier to work with mathematically
|
998 |
+
2. It avoids numerical underflow when multiplying many small probabilities
|
999 |
+
3. Logarithm is a monotonically increasing function, so the maximum of the likelihood occurs at the same parameter values as the maximum of the log-likelihood
|
1000 |
+
///
|
1001 |
+
|
1002 |
+
/// details | For a Bernoulli distribution, the MLE for parameter p is the sample mean of the observations.
|
1003 |
+
✅ **Correct!**
|
1004 |
+
|
1005 |
+
For a Bernoulli distribution with parameter $p$, given $n$ independent samples $X_1, X_2, \ldots, X_n$, the MLE estimator is:
|
1006 |
+
|
1007 |
+
$$\hat{p} = \frac{\sum_{i=1}^n X_i}{n}$$
|
1008 |
+
|
1009 |
+
This is simply the sample mean, or the proportion of successes (1s) in the data.
|
1010 |
+
///
|
1011 |
+
|
1012 |
+
/// details | For a Normal distribution, MLE gives unbiased estimates for both mean and variance parameters.
|
1013 |
+
❌ **Incorrect.**
|
1014 |
+
|
1015 |
+
While the MLE for the mean ($\hat{\mu} = \frac{1}{n}\sum_{i=1}^n X_i$) is unbiased, the MLE for variance:
|
1016 |
+
|
1017 |
+
$$\hat{\sigma}^2 = \frac{1}{n}\sum_{i=1}^n (X_i - \hat{\mu})^2$$
|
1018 |
+
|
1019 |
+
is a biased estimator. It uses $n$ in the denominator rather than $n-1$ used in the unbiased estimator.
|
1020 |
+
///
|
1021 |
+
|
1022 |
+
/// details | MLE estimators are always unbiased regardless of the distribution.
|
1023 |
+
❌ **Incorrect.**
|
1024 |
+
|
1025 |
+
MLE is not always unbiased, though it often is asymptotically unbiased (meaning the bias approaches zero as the sample size increases).
|
1026 |
+
|
1027 |
+
A notable example is the MLE estimator for the variance of a Normal distribution:
|
1028 |
+
$$\hat{\sigma}^2 = \frac{1}{n}\sum_{i=1}^n (X_i - \hat{\mu})^2$$
|
1029 |
+
|
1030 |
+
This estimator is biased, which is why we often use the unbiased estimator:
|
1031 |
+
$$s^2 = \frac{1}{n-1}\sum_{i=1}^n (X_i - \hat{\mu})^2$$
|
1032 |
+
|
1033 |
+
Despite occasional bias, MLE estimators have many desirable properties, including consistency and asymptotic efficiency.
|
1034 |
+
///
|
1035 |
+
"""
|
1036 |
+
)
|
1037 |
+
return
|
1038 |
+
|
1039 |
+
|
1040 |
+
@app.cell(hide_code=True)
|
1041 |
+
def _(mo):
|
1042 |
+
mo.md(
|
1043 |
+
r"""
|
1044 |
+
## Summary
|
1045 |
+
|
1046 |
+
Maximum Likelihood Estimation really is one of those elegant ideas that sits at the core of modern statistics. When you get down to it, MLE is just about finding the most plausible explanation for the data we've observed. It's like being a detective - you have some clues (your data), and you're trying to piece together the most likely story (your parameters) that explains them.
|
1047 |
+
|
1048 |
+
We've seen how this works with different distributions. For the Bernoulli, it simply gives us the sample proportion. For the Normal, it gives us the sample mean and a slightly biased estimate of variance. And for linear regression, it provides a mathematical justification for the least squares method that everyone learns in basic stats classes.
|
1049 |
+
|
1050 |
+
What makes MLE so useful in practice is that it tends to give us estimates with good properties. As you collect more data, the estimates generally get closer to the true values (consistency) and do so efficiently. That's why MLE is everywhere in statistics and machine learning - from simple regression models to complex neural networks.
|
1051 |
+
|
1052 |
+
The most important takeaway? Next time you're fitting a model to data, remember that you're not just following a recipe - you're finding the parameters that make your observed data most likely to have occurred. That's the essence of statistical inference.
|
1053 |
+
"""
|
1054 |
+
)
|
1055 |
+
return
|
1056 |
+
|
1057 |
+
|
1058 |
+
@app.cell(hide_code=True)
|
1059 |
+
def _(mo):
|
1060 |
+
mo.md(
|
1061 |
+
r"""
|
1062 |
+
## Further Reading
|
1063 |
+
|
1064 |
+
If you're curious to dive deeper into this topic, check out "Statistical Inference" by Casella and Berger - it's the classic text that many statisticians learned from. For a more machine learning angle, Bishop's "Pattern Recognition and Machine Learning" shows how MLE connects to more advanced topics like EM algorithms and Bayesian methods.
|
1065 |
+
|
1066 |
+
Beyond the basics we've covered, you might explore Bayesian estimation (which incorporates prior knowledge), Fisher Information (which tells us how precisely we can estimate parameters), or the EM algorithm (for when we have missing data or latent variables). Each of these builds on the foundation of likelihood that we've established here.
|
1067 |
+
"""
|
1068 |
+
)
|
1069 |
+
return
|
1070 |
+
|
1071 |
+
|
1072 |
+
@app.cell(hide_code=True)
|
1073 |
+
def _(mo):
|
1074 |
+
mo.md(r"""## Appendix (helper functions and imports)""")
|
1075 |
+
return
|
1076 |
+
|
1077 |
+
|
1078 |
+
@app.cell
|
1079 |
+
def _():
|
1080 |
+
import marimo as mo
|
1081 |
+
return (mo,)
|
1082 |
+
|
1083 |
+
|
1084 |
+
@app.cell
|
1085 |
+
def _():
|
1086 |
+
import numpy as np
|
1087 |
+
import matplotlib.pyplot as plt
|
1088 |
+
from scipy import stats
|
1089 |
+
import plotly.graph_objects as go
|
1090 |
+
import polars as pl
|
1091 |
+
from matplotlib import cm
|
1092 |
+
|
1093 |
+
# Set a consistent random seed for reproducibility
|
1094 |
+
np.random.seed(42)
|
1095 |
+
|
1096 |
+
# Set a nice style for matplotlib
|
1097 |
+
plt.style.use('seaborn-v0_8-darkgrid')
|
1098 |
+
return cm, go, np, pl, plt, stats
|
1099 |
+
|
1100 |
+
|
1101 |
+
@app.cell(hide_code=True)
|
1102 |
+
def _(mo):
|
1103 |
+
# Create interactive elements
|
1104 |
+
true_p_slider = mo.ui.slider(
|
1105 |
+
start =0.01,
|
1106 |
+
stop =0.99,
|
1107 |
+
value=0.3,
|
1108 |
+
step=0.01,
|
1109 |
+
label="True probability (p)"
|
1110 |
+
)
|
1111 |
+
|
1112 |
+
sample_size_slider = mo.ui.slider(
|
1113 |
+
start =10,
|
1114 |
+
stop =1000,
|
1115 |
+
value=100,
|
1116 |
+
step=10,
|
1117 |
+
label="Sample size (n)"
|
1118 |
+
)
|
1119 |
+
|
1120 |
+
generate_button = mo.ui.button(label="Generate New Sample", kind="success")
|
1121 |
+
|
1122 |
+
controls = mo.vstack([
|
1123 |
+
mo.vstack([true_p_slider, sample_size_slider]),
|
1124 |
+
generate_button
|
1125 |
+
], justify="space-between")
|
1126 |
+
return controls, generate_button, sample_size_slider, true_p_slider
|
1127 |
+
|
1128 |
+
|
1129 |
+
@app.cell(hide_code=True)
|
1130 |
+
def _(mo):
|
1131 |
+
# Create interactive elements for Normal distribution
|
1132 |
+
true_mu_slider = mo.ui.slider(
|
1133 |
+
start =-5,
|
1134 |
+
stop =5,
|
1135 |
+
value=0,
|
1136 |
+
step=0.1,
|
1137 |
+
label="True mean (μ)"
|
1138 |
+
)
|
1139 |
+
|
1140 |
+
true_sigma_slider = mo.ui.slider(
|
1141 |
+
start =0.5,
|
1142 |
+
stop =3,
|
1143 |
+
value=1,
|
1144 |
+
step=0.1,
|
1145 |
+
label="True standard deviation (σ)"
|
1146 |
+
)
|
1147 |
+
|
1148 |
+
normal_sample_size_slider = mo.ui.slider(
|
1149 |
+
start =10,
|
1150 |
+
stop =500,
|
1151 |
+
value=50,
|
1152 |
+
step=10,
|
1153 |
+
label="Sample size (n)"
|
1154 |
+
)
|
1155 |
+
|
1156 |
+
normal_generate_button = mo.ui.button(label="Generate New Sample", kind="warn")
|
1157 |
+
|
1158 |
+
normal_controls = mo.hstack([
|
1159 |
+
mo.vstack([true_mu_slider, true_sigma_slider, normal_sample_size_slider]),
|
1160 |
+
normal_generate_button
|
1161 |
+
], justify="space-between")
|
1162 |
+
return (
|
1163 |
+
normal_controls,
|
1164 |
+
normal_generate_button,
|
1165 |
+
normal_sample_size_slider,
|
1166 |
+
true_mu_slider,
|
1167 |
+
true_sigma_slider,
|
1168 |
+
)
|
1169 |
+
|
1170 |
+
|
1171 |
+
@app.cell(hide_code=True)
|
1172 |
+
def _(mo):
|
1173 |
+
# Create interactive elements for linear regression
|
1174 |
+
true_theta_slider = mo.ui.slider(
|
1175 |
+
start =-2,
|
1176 |
+
stop =2,
|
1177 |
+
value=0.5,
|
1178 |
+
step=0.1,
|
1179 |
+
label="True slope (θ)"
|
1180 |
+
)
|
1181 |
+
|
1182 |
+
noise_sigma_slider = mo.ui.slider(
|
1183 |
+
start =0.1,
|
1184 |
+
stop =2,
|
1185 |
+
value=0.5,
|
1186 |
+
step=0.1,
|
1187 |
+
label="Noise level (σ)"
|
1188 |
+
)
|
1189 |
+
|
1190 |
+
linear_sample_size_slider = mo.ui.slider(
|
1191 |
+
start =10,
|
1192 |
+
stop =200,
|
1193 |
+
value=50,
|
1194 |
+
step=10,
|
1195 |
+
label="Sample size (n)"
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
linear_generate_button = mo.ui.button(label="Generate New Sample", kind="warn")
|
1199 |
+
|
1200 |
+
linear_controls = mo.hstack([
|
1201 |
+
mo.vstack([true_theta_slider, noise_sigma_slider, linear_sample_size_slider]),
|
1202 |
+
linear_generate_button
|
1203 |
+
], justify="space-between")
|
1204 |
+
return (
|
1205 |
+
linear_controls,
|
1206 |
+
linear_generate_button,
|
1207 |
+
linear_sample_size_slider,
|
1208 |
+
noise_sigma_slider,
|
1209 |
+
true_theta_slider,
|
1210 |
+
)
|
1211 |
+
|
1212 |
+
|
1213 |
+
@app.cell(hide_code=True)
|
1214 |
+
def _(mo):
|
1215 |
+
# Interactive elements for likelihood vs probability demo
|
1216 |
+
concept_dist_type = mo.ui.dropdown(
|
1217 |
+
options=["Normal", "Bernoulli", "Poisson"],
|
1218 |
+
value="Normal",
|
1219 |
+
label="Distribution"
|
1220 |
+
)
|
1221 |
+
|
1222 |
+
# Replace buttons with a simple dropdown selector
|
1223 |
+
perspective_selector = mo.ui.dropdown(
|
1224 |
+
options=["Probability Perspective", "Likelihood Perspective"],
|
1225 |
+
value="Probability Perspective",
|
1226 |
+
label="View"
|
1227 |
+
)
|
1228 |
+
|
1229 |
+
concept_controls = mo.vstack([
|
1230 |
+
mo.hstack([concept_dist_type, perspective_selector])
|
1231 |
+
])
|
1232 |
+
return concept_controls, concept_dist_type, perspective_selector
|
1233 |
+
|
1234 |
+
|
1235 |
+
if __name__ == "__main__":
|
1236 |
+
app.run()
|