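"""Streamlit demo: fit a straight line to noisy synthetic data with a one-unit Keras model.

The lines that originally headed this file ("Spaces:", "Sleeping", "Sleeping")
appear to be hosting-page status text captured along with the source, not code.
"""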
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Title
st.title("Neural Network Line Fitting")

# Sidebar sliders for generating synthetic data
st.sidebar.header("Synthetic Data Controls")
true_w = st.sidebar.slider('True W (slope)', min_value=-10.0, max_value=10.0, value=2.0, step=0.1)
true_b = st.sidebar.slider('True B (intercept)', min_value=-10.0, max_value=10.0, value=1.0, step=0.1)
num_points = st.sidebar.slider('Number of data points', min_value=10, max_value=1000, value=100, step=10)

# Generate synthetic data: y = true_w * x + true_b plus Gaussian noise (std 10),
# with x drawn uniformly from [-100, 100].
# A dedicated fixed-seed Generator keeps the dataset reproducible for a given
# set of slider values without mutating NumPy's global random state.
rng = np.random.default_rng(0)
x_data = rng.uniform(-100, 100, num_points)
noise = rng.normal(0, 10, num_points)
y_data = true_w * x_data + true_b + noise
# Neural network model: one Dense unit with no activation computes
# y = w * x + b, so its kernel and bias are exactly a line's slope and
# intercept.
#
# The network is trained on standardized data (zero mean, unit variance).
# Raw x spans [-100, 100] and y can reach roughly +/-1000, while Adam's
# default learning rate (0.001) moves each weight by roughly at most 0.001
# per update. Raw-unit training therefore cannot reach slopes like 10 within
# 100 epochs. Standardizing, using a larger learning rate, and mapping the
# weights back afterwards lets the fit converge across the whole slider range.
tf.keras.utils.set_random_seed(0)  # reproducible weight init and batch shuffling

x_mean = float(x_data.mean())
x_std = float(x_data.std()) or 1.0  # guard against a degenerate (zero) spread
y_mean = float(y_data.mean())
y_std = float(y_data.std()) or 1.0

# Keras expects 2-D (n_samples, n_features) arrays; here that is (n, 1).
x_train = ((x_data - x_mean) / x_std).reshape(-1, 1)
y_train = ((y_data - y_mean) / y_std).reshape(-1, 1)

# Keep a direct handle on the Dense layer so its weights can be read back
# without depending on how `model.layers` lists the input layer.
line_layer = Dense(1)
model = Sequential([
    tf.keras.Input(shape=(1,)),
    line_layer,
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.1), loss='mean_squared_error')

# Train the model
model.fit(x_train, y_train, epochs=100, verbose=0)

# Get the trained parameters and convert them back to the original units:
#   y = y_std * (w_s * (x - x_mean) / x_std + b_s) + y_mean
kernel, bias = line_layer.get_weights()
w_scaled = float(kernel[0][0])
b_scaled = float(bias[0])
trained_w = w_scaled * y_std / x_std
trained_b = y_std * b_scaled + y_mean - trained_w * x_mean

# Make predictions with the network on the standardized scale, then undo the
# y scaling so the fitted curve is expressed in the original units.
x_pred = np.linspace(-100, 100, 200)
y_pred_scaled = model.predict(((x_pred - x_mean) / x_std).reshape(-1, 1), verbose=0)
y_pred = y_pred_scaled.ravel() * y_std + y_mean
# Plot the results
fig, ax = plt.subplots(figsize=(10, 5))

# Draw the real coordinate axes through the origin.
ax.axhline(0, color='blue', linestyle='--', linewidth=0.8)  # x-axis (y = 0)
ax.axvline(0, color='blue', linestyle='--', linewidth=0.8)  # y-axis (x = 0)

# Plot the synthetic data points
ax.scatter(x_data, y_data, color='gray', alpha=0.5, label='Data points')

# Plot the line the data was generated from, for visual comparison
ax.plot(x_pred, true_w * x_pred + true_b, color='green', linestyle=':',
        label=f'True line: y = {true_w:.2f}x + {true_b:.2f}')

# Plot the prediction line
ax.plot(x_pred, y_pred, color='red', label=f'Fitted line: y = {trained_w:.2f}x + {trained_b:.2f}')

# Update the layout. The y-range is left to autoscale: with slopes up to 10
# and x in [-100, 100], y spans roughly +/-1000, so a fixed narrow y-range
# would hide the data and the fitted line.
ax.set_xlim(-100, 100)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Neural Network Line Fitting')
ax.legend()
ax.grid(True)

# Display the plot in Streamlit
st.pyplot(fig)
# pyplot keeps every figure created by plt.subplots open until it is closed,
# and Streamlit re-runs this script on each widget change. Release the figure
# once it has been rendered.
plt.close(fig)

# Display the trained parameters
st.write(f'Trained parameters: w = {trained_w:.2f}, b = {trained_b:.2f}')