#LogisticRegression

import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Define the Logistic Regression Model
class LogisticRegressionModel(nn.Module):
    """Single-feature logistic regression.

    The model is one ``nn.Linear`` layer with one input and one output
    feature whose result is passed through a sigmoid.
    """

    def __init__(self):
        """Create the model's only layer, stored as ``self.linear``."""
        super().__init__()
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Return ``sigmoid(linear(x))``.

        Args:
            x: Input tensor whose last dimension has size 1 (required by
                the 1-input linear layer).

        Returns:
            A tensor shaped like the linear layer's output, with every
            element squashed into the open interval (0, 1) by the sigmoid.
        """
        logits = self.linear(x)
        return torch.sigmoid(logits)

# Generate synthetic data
# Both RNGs are seeded so the dataset (NumPy) and any later torch-side
# randomness are reproducible from run to run.
torch.manual_seed(0)
np.random.seed(0)
n_samples = 100
# One feature column: study hours drawn uniformly from [0, 10).
X = 10 * np.random.rand(n_samples, 1)
# Label is 1 (pass) when study hours exceed 5, otherwise 0 (fail); 1-D array.
y = (X[:, 0] > 5).astype(int)

# Convert to torch tensors (both float32; X stays (n_samples, 1), y stays 1-D)
X_tensor, y_tensor = (
    torch.tensor(array, dtype=torch.float32) for array in (X, y)
)

# Streamlit interface
st.title('Logistic Regression with PyTorch')

# User inputs
# Training length: 100 to 5000 epochs in steps of 100, defaulting to 1000.
num_epochs = st.number_input(
    'Number of Epochs',
    value=1000,
    min_value=100,
    max_value=5000,
    step=100,
)
# SGD step size: 0.0001 to 0.1, shown with four decimals, defaulting to 0.01.
learning_rate = st.number_input(
    'Learning Rate',
    value=0.01,
    min_value=0.0001,
    max_value=0.1,
    step=0.0001,
    format="%.4f",
)
# Raw text at this point; it is split on commas and converted to floats
# further down, just before the predictions are made.
test_hours = st.text_input(
    'Test Study Hours (comma separated)',
    value='4.0, 6.0, 9.0',
)

# Initialize the model
model = LogisticRegressionModel()

# Binary cross-entropy on the model's sigmoid outputs, minimised with plain
# stochastic gradient descent at the user-selected learning rate.
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Training the model
# Every epoch is one full-batch step over the whole dataset. The targets are
# reshaped once, up front, into an (N, 1) column so they match the model's
# (N, 1) output as BCELoss requires.
targets = y_tensor.view(-1, 1)
loss_values = []
for _ in range(num_epochs):
    optimizer.zero_grad()
    loss = criterion(model(X_tensor), targets)
    loss.backward()
    optimizer.step()
    # Keep the scalar loss of each epoch for the loss-curve plot.
    loss_values.append(loss.item())

# Plot the loss curve
# One point per training epoch: x is the epoch index, y is that epoch's loss.
fig, ax = plt.subplots()
ax.plot(range(num_epochs), loss_values)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss Curve')
st.pyplot(fig)
# Streamlit re-executes this script on every widget interaction, and figures
# created through pyplot stay registered (and in memory) until they are
# explicitly closed. Release this one now that it has been rendered so that
# figures do not pile up across reruns.
plt.close(fig)

# Evaluation
model.eval()

# `test_hours` arrives as the raw text typed by the user. Split it on commas
# and drop blank entries (e.g. from a trailing comma). Anything else that is
# not a number is reported in the UI instead of crashing the app with an
# unhandled ValueError.
raw_entries = [entry.strip() for entry in test_hours.split(',')]
try:
    parsed_hours = [float(entry) for entry in raw_entries if entry]
except ValueError:
    st.error('Test study hours must be numbers separated by commas, e.g. "4.0, 6.0, 9.0".')
    st.stop()  # halts this script run; nothing below executes

if not parsed_hours:
    st.warning('Enter at least one study-hours value to get a prediction.')
    st.stop()

# From here on `test_hours` is the parsed list of floats, as before.
test_hours = parsed_hours
# Column vector of shape (len(test_hours), 1) to match the model's single input feature.
test_tensor = torch.tensor(test_hours, dtype=torch.float32).view(-1, 1)
predictions = model(test_tensor).detach().numpy()

# Display predictions
st.write('## Predictions')
for test_hour, prediction in zip(test_hours, predictions):
    # Each prediction row holds a single value: the predicted pass probability.
    st.write(f"Study hours: {test_hour}, Predicted pass probability: {prediction[0]:.4f}")