# TensorFlowClass / pages/42_regression.py
# Hosting-page metadata (not code): last commit "Update pages/42_regression.py" (9f3e74e, verified) by eaglelandsonce; file size 2.1 kB.
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
# Page heading
st.title("Neural Network Line Fitting")

# Sidebar: ground-truth line parameters and sample count chosen by the user
sidebar = st.sidebar
sidebar.header("Synthetic Data Controls")
true_w = sidebar.slider(
    'True W (slope)',
    min_value=-10.0,
    max_value=10.0,
    value=2.0,
    step=0.1,
)
true_b = sidebar.slider(
    'True B (intercept)',
    min_value=-10.0,
    max_value=10.0,
    value=1.0,
    step=0.1,
)
num_points = sidebar.slider(
    'Number of data points',
    min_value=10,
    max_value=1000,
    value=100,
    step=10,
)

# Synthetic samples: y = true_w * x + true_b plus Gaussian noise.
# The global NumPy seed is fixed so every rerun draws the same x values and
# noise for a given number of points.
np.random.seed(0)
x_data = np.random.uniform(low=-100, high=100, size=num_points)
noise = np.random.normal(loc=0, scale=10, size=num_points)
noiseless_y = true_w * x_data + true_b
y_data = noiseless_y + noise
# Neural network model: one Dense unit with no activation, i.e. y = w * x + b.
# The layer is kept in a variable so its weights can be read back directly
# instead of relying on its position in `model.layers`.
dense_layer = Dense(1)
model = Sequential([
    tf.keras.Input(shape=(1,)),
    dense_layer,
])
# Adam changes each parameter by roughly `learning_rate` per step. With the
# default 0.001 and only a few hundred steps (100 epochs, batch size 32) the
# slope and intercept could move by well under 1.0 in total and never reached
# the slider values. A rate of 0.1 lets them travel the needed distance
# within the same number of epochs.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
    loss='mean_squared_error',
)
# Train the model
# Keras works on (samples, features) arrays, so the single feature axis is
# made explicit rather than relying on implicit rank expansion of 1-D input.
x_train = np.asarray(x_data).reshape(-1, 1)
y_train = np.asarray(y_data).reshape(-1, 1)
model.fit(x_train, y_train, epochs=100, verbose=0)
# Get the trained parameters (kernel has shape (1, 1), bias has shape (1,))
kernel, bias = dense_layer.get_weights()
trained_w = float(kernel[0][0])
trained_b = float(bias[0])
# Make predictions on an evenly spaced grid over the sampled x range;
# verbose=0 keeps the progress bar out of the server log on every rerun.
x_pred = np.linspace(-100, 100, 200)
y_pred = model.predict(x_pred.reshape(-1, 1), verbose=0)
# Plot the results
fig, ax = plt.subplots(figsize=(10, 5))
# Plot the synthetic data points
ax.scatter(x_data, y_data, color='gray', alpha=0.5, label='Data points')
# Plot the prediction line
ax.plot(x_pred, y_pred, color='red', label=f'Fitted line: y = {trained_w:.2f}x + {trained_b:.2f}')
# Update the layout
# x is pinned to the range the samples are drawn from. The y limits are left
# to autoscale: the data spans roughly +/-(100 * |slope|), so the previous
# fixed (-2, 2) window hid nearly every point and most of the fitted line.
ax.set_xlim(-100, 100)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Neural Network Line Fitting')
ax.legend()
ax.grid(True)
# Display the plot in Streamlit
st.pyplot(fig)
# The script reruns on every widget change; close the figure so pyplot does
# not keep one more open figure alive per rerun.
plt.close(fig)
# Display the trained parameters
st.write(f'Trained parameters: w = {trained_w:.2f}, b = {trained_b:.2f}')