eaglelandsonce commited on
Commit
8bd1220
·
verified ·
1 Parent(s): 270150e

Update pages/2_LinearRegression.py

Browse files
Files changed (1) hide show
  1. pages/2_LinearRegression.py +41 -74
pages/2_LinearRegression.py CHANGED
@@ -1,82 +1,49 @@
1
  import streamlit as st
 
 
2
  import torch
3
  import torch.nn as nn
4
- import torch.optim as optim
5
- import matplotlib.pyplot as plt
6
 
7
- # Define the dataset
8
- def generate_data(n_samples):
9
- torch.manual_seed(42)
10
- X = torch.randn(n_samples, 1) * 10
11
- y = 2 * X + 3 + torch.randn(n_samples, 1) * 3
12
- return X, y
13
 
14
- # Define the linear regression model
15
- class LinearRegressionModel(nn.Module):
16
- def __init__(self):
17
- super(LinearRegressionModel, self).__init__()
18
- self.linear = nn.Linear(1, 1)
19
-
20
  def forward(self, x):
21
  return self.linear(x)
22
 
23
- # Train the model
24
- def train_model(X, y, lr, epochs):
25
- model = LinearRegressionModel()
26
- criterion = nn.MSELoss()
27
- optimizer = optim.SGD(model.parameters(), lr=lr)
28
-
29
- for epoch in range(epochs):
30
- model.train()
31
- optimizer.zero_grad()
32
- outputs = model(X)
33
- loss = criterion(outputs, y)
34
- loss.backward()
35
- optimizer.step()
36
-
37
- return model
38
-
39
- # Plot the results
40
- def plot_results(X, y, model):
41
- plt.scatter(X.numpy(), y.numpy(), label='Original data')
42
- plt.plot(X.numpy(), model(X).detach().numpy(), label='Fitted line', color='r')
43
- plt.legend()
44
- plt.xlabel('X')
45
- plt.ylabel('y')
46
- st.pyplot(plt.gcf())
47
-
48
- # Streamlit interface
49
- st.title('Simple Linear Regression with PyTorch')
50
- n_samples = st.slider('Number of samples', 20, 100, 50)
51
- learning_rate = st.slider('Learning rate', 0.001, 0.1, 0.01)
52
- epochs = st.slider('Number of epochs', 100, 1000, 500)
53
-
54
- X, y = generate_data(n_samples)
55
- model = train_model(X, y, learning_rate, epochs)
56
-
57
- st.subheader('Training Data')
58
- plot_results(X, y, model)
59
-
60
- st.subheader('Model Parameters')
61
- st.write(f'Weight: {model.linear.weight.item()}')
62
- st.write(f'Bias: {model.linear.bias.item()}')
63
-
64
- st.subheader('Loss Curve')
65
- losses = []
66
- model = LinearRegressionModel()
67
- criterion = nn.MSELoss()
68
- optimizer = optim.SGD(model.parameters(), lr=learning_rate)
69
- for epoch in range(epochs):
70
- model.train()
71
- optimizer.zero_grad()
72
- outputs = model(X)
73
- loss = criterion(outputs, y)
74
- loss.backward()
75
- optimizer.step()
76
- losses.append(loss.item())
77
-
78
- plt.figure()
79
- plt.plot(range(epochs), losses)
80
- plt.xlabel('Epoch')
81
- plt.ylabel('Loss')
82
- st.pyplot(plt.gcf())
 
1
  import streamlit as st
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
  import torch
5
  import torch.nn as nn
 
 
6
 
7
+ # Set a seed for reproducibility
8
+ torch.manual_seed(59)
 
 
 
 
9
 
10
+ # Define the Linear Model
11
+ class LinearModel(nn.Module):
12
+ def __init__(self, in_features, out_features):
13
+ super(LinearModel, self).__init__()
14
+ self.linear = nn.Linear(in_features, out_features)
15
+
16
  def forward(self, x):
17
  return self.linear(x)
18
 
19
+ # Instantiate the model
20
+ model = LinearModel(1, 1)
21
+
22
+ # Print model weight and bias
23
+ print(f'Model weight: {model.linear.weight.item()}')
24
+ print(f'Model bias: {model.linear.bias.item()}')
25
+
26
+ # Streamlit app title
27
+ st.title('Interactive Scatter Plot with Noise and Number of Data Points')
28
+
29
+ # Sidebar sliders for noise and number of data points
30
+ noise_level = st.sidebar.slider('Noise Level', 0.0, 1.0, 0.1, step=0.01)
31
+ num_points = st.sidebar.slider('Number of Data Points', 10, 100, 50, step=5)
32
+
33
+ # Generate data
34
+ np.random.seed(0)
35
+ x = np.linspace(0, 10, num_points).reshape(-1, 1).astype(np.float32)
36
+ with torch.no_grad():
37
+ x_tensor = torch.tensor(x)
38
+ y_tensor = model(x_tensor)
39
+ y = y_tensor.numpy().flatten() + noise_level * np.random.randn(num_points)
40
+
41
+ # Create scatter plot
42
+ fig, ax = plt.subplots()
43
+ ax.scatter(x, y, alpha=0.6)
44
+ ax.set_title('Scatter Plot with Noise and Number of Data Points')
45
+ ax.set_xlabel('X-axis')
46
+ ax.set_ylabel('Y-axis')
47
+
48
+ # Display plot in Streamlit
49
+ st.pyplot(fig)