from flask import Flask
from flask_restx import Api, Resource, fields
from werkzeug.datastructures import FileStorage
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
import joblib
import streamlit as st
import pandas as pd
import requests
import threading
import json

app = Flask(__name__)

# Swagger-documented REST API mounted on the Flask application.
api = Api(
    app,
    version='1.0',
    title='Car Depreciation Model API',
    description='API for creating and testing car depreciation models',
)

# Route groups: one for training a model, one for querying it.
model_ns = api.namespace('model', description='Model operations')
predict_ns = api.namespace('predict', description='Prediction operations')

# Multipart parser for the training upload; the file arrives in the
# form field named "file".
upload_parser = api.parser()
upload_parser.add_argument(
    'file',
    location='files',
    type=FileStorage,
    required=True,
)

# JSON schema advertised for the prediction endpoint. The field names are
# the same feature column names the training pipeline selects.
input_model = api.model(
    'PredictionInput',
    {
        'Car_Model': fields.String(required=True, description='Car model'),
        'Car_Year': fields.Integer(required=True, description='Year of the car'),
        'Assessment_Year': fields.Integer(required=True, description='Assessment year'),
        'Starting_Asset_Value': fields.Float(required=True, description='Starting asset value'),
        'Book_Residual_Value': fields.Float(required=True, description='Book residual value'),
        'Market_Value': fields.Float(required=True, description='Market value'),
    },
)

# Active fitted pipeline shared by the resources below. It stays None until
# a model is trained via /model/create or lazily loaded from disk on the
# first prediction.
global_model = None

@model_ns.route('/create')
@api.expect(upload_parser)
class ModelCreation(Resource):
    """Train, evaluate and persist a car-depreciation model from a CSV upload.

    The CSV must contain the target column ``Depreciation_Percent`` plus the
    feature columns listed below. Any other columns are passed to the
    pipeline but dropped by its ColumnTransformer (default ``remainder='drop'``).
    """

    # Regression target column in the uploaded CSV.
    _TARGET = 'Depreciation_Percent'
    # Feature columns fed to the pipeline, grouped by preprocessing type.
    _NUMERIC_FEATURES = ['Car_Year', 'Assessment_Year', 'Starting_Asset_Value',
                         'Book_Residual_Value', 'Market_Value']
    _CATEGORICAL_FEATURES = ['Car_Model']
    # Where the fitted pipeline is written; the prediction endpoint loads
    # from this same path.
    _MODEL_PATH = 'output/car_depreciation_model.joblib'

    @classmethod
    def _build_pipeline(cls):
        """Return an unfitted imputation + one-hot + LinearRegression pipeline."""
        # Numeric gaps are filled with the column median.
        numeric_transformer = SimpleImputer(strategy='median')
        # Missing categories become the literal 'missing' before one-hot
        # encoding; handle_unknown='ignore' lets unseen car models through
        # at prediction time instead of raising.
        categorical_transformer = Pipeline(steps=[
            ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
            ('onehot', OneHotEncoder(handle_unknown='ignore')),
        ])
        preprocessor = ColumnTransformer(transformers=[
            ('num', numeric_transformer, cls._NUMERIC_FEATURES),
            ('cat', categorical_transformer, cls._CATEGORICAL_FEATURES),
        ])
        return Pipeline(steps=[('preprocessor', preprocessor),
                               ('regressor', LinearRegression())])

    @api.doc(description='Create a new model from CSV data')
    @api.response(200, 'Model created successfully')
    @api.response(400, 'Invalid input')
    def post(self):
        """Fit a new model on the uploaded CSV and make it the active model.

        Holds out 20% of the rows (``random_state=42``) to report test-set
        MSE and R^2. The fitted pipeline is saved to ``_MODEL_PATH`` and also
        stored in the module-level ``global_model``.

        Returns:
            ``({'message', 'mse', 'r2'}, 200)`` on success, or
            ``({'error': ...}, 400)`` if the file is not a CSV, cannot be
            parsed, lacks required columns, or cannot be used for training.
        """
        global global_model
        import os

        args = upload_parser.parse_args()
        uploaded_file = args['file']

        # Case-insensitive extension check so "DATA.CSV" is accepted too.
        filename = (uploaded_file.filename or '') if uploaded_file else ''
        if not filename.lower().endswith('.csv'):
            return {'error': 'Invalid file format'}, 400

        try:
            data = pd.read_csv(uploaded_file)
        except ValueError as exc:  # covers pandas ParserError / EmptyDataError
            return {'error': f'Could not parse CSV: {exc}'}, 400

        # Report missing columns as a client error rather than a KeyError (500).
        required = [self._TARGET, *self._NUMERIC_FEATURES, *self._CATEGORICAL_FEATURES]
        missing = [col for col in required if col not in data.columns]
        if missing:
            return {'error': f'Missing required columns: {", ".join(missing)}'}, 400

        X = data.drop(columns=self._TARGET)
        y = data[self._TARGET]

        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42)
            model = self._build_pipeline()
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            mse = mean_squared_error(y_test, y_pred)
            r2 = r2_score(y_test, y_pred)
        except ValueError as exc:
            # scikit-learn raises ValueError for unusable data (too few rows
            # to split, non-numeric or NaN targets, ...), which is an upload problem.
            return {'error': f'Could not train model: {exc}'}, 400

        # joblib.dump does not create parent directories itself.
        os.makedirs(os.path.dirname(self._MODEL_PATH), exist_ok=True)
        joblib.dump(model, self._MODEL_PATH)
        global_model = model

        return {
            'message': 'Model created and saved successfully',
            'mse': float(mse),
            'r2': float(r2)
        }, 200

@predict_ns.route('/')
class Prediction(Resource):
    """Serve depreciation predictions from the active trained pipeline."""

    # Must match the path the /model/create endpoint writes the pipeline to.
    _MODEL_PATH = 'output/car_depreciation_model.joblib'

    @api.expect(input_model)
    @api.doc(description='Predict car depreciation')
    @api.response(200, 'Successful prediction')
    @api.response(400, 'Invalid input')
    @api.response(404, 'Model not found')
    def post(self):
        """Predict depreciation for one car described by the JSON body.

        If no model is in memory, the saved pipeline is loaded from
        ``_MODEL_PATH`` first.

        Returns:
            ``({'predicted_depreciation': float}, 200)`` on success;
            ``({'error': ...}, 400)`` for a malformed body or input the
            model rejects; ``({'error': ...}, 404)`` when no model exists yet.
        """
        global global_model
        from werkzeug.exceptions import HTTPException

        if global_model is None:
            try:
                global_model = joblib.load(self._MODEL_PATH)
            except FileNotFoundError:
                return {'error': 'Model not found. Please create a model first.'}, 404
            # Other load failures (e.g. a corrupt file) are server faults and
            # are deliberately not reported as a client 400.

        try:
            data = api.payload
        except HTTPException as exc:  # malformed JSON / wrong content type
            return {'error': exc.description}, 400

        # @api.expect(input_model) only documents the schema (no validate=True),
        # so check the body shape and required fields explicitly.
        if not isinstance(data, dict):
            return {'error': 'Request body must be a JSON object.'}, 400
        missing = [name for name in input_model if name not in data]
        if missing:
            return {'error': f'Missing required fields: {", ".join(missing)}'}, 400

        try:
            prediction = global_model.predict(pd.DataFrame([data]))
        except (ValueError, TypeError, KeyError) as exc:
            # Bad values (e.g. non-numeric text in a numeric field) are client errors.
            return {'error': str(exc)}, 400

        return {
            'predicted_depreciation': float(prediction[0])
        }, 200


# Base URL of the Flask API (Flask's default host/port, as started by the launcher).
API_URL = "http://localhost:5000"
# Seconds to wait for the API so the UI never blocks indefinitely.
REQUEST_TIMEOUT = 10

st.title('Car Depreciation Predictor')

# Input form for prediction
st.header('Predict Depreciation')
car_model = st.text_input('Car Model', value="Honda Civic")
car_year = st.number_input('Car Year', value=2022)
assessment_year = st.number_input('Assessment Year', min_value=1, max_value=5, value=1)
starting_asset_value = st.number_input('Starting Asset Value', min_value=0, value=20000)
book_residual_value = st.number_input('Book Residual Value', min_value=0, value=18000)
market_value = st.number_input('Market Value', min_value=0, value=19000)

if st.button('Predict'):
    # Keys and types match the API's PredictionInput schema.
    input_data = {
        'Car_Model': car_model,
        'Car_Year': int(car_year),
        'Assessment_Year': int(assessment_year),
        'Starting_Asset_Value': float(starting_asset_value),
        'Book_Residual_Value': float(book_residual_value),
        'Market_Value': float(market_value)
    }

    try:
        response = requests.post(f'{API_URL}/predict/', json=input_data,
                                 timeout=REQUEST_TIMEOUT)
    except requests.RequestException as exc:
        # The API may not be up yet (it starts in a background thread) or may be unreachable.
        st.error(f'Could not reach the prediction API at {API_URL}: {exc}')
    else:
        if response.status_code == 200:
            prediction = response.json()['predicted_depreciation']
            st.success(f'Predicted Depreciation: {prediction:.2f}%')
        elif response.status_code == 404:
            st.error('Model not found. Please create a model first.')
        else:
            # Error bodies are not guaranteed to be a JSON object with an
            # "error" key (flask-restx's own errors use "message").
            try:
                body = response.json()
            except ValueError:
                body = None
            if isinstance(body, dict):
                detail = body.get("error") or body.get("message") or "Unknown error"
            else:
                detail = "Unknown error"
            st.error(f'Error making prediction: {detail}')

if __name__ == '__main__':
    import os
    import sys

    # Streamlit executes this script as __main__ inside the same process on
    # every rerun. The environment flag makes the launcher run only once, so
    # reruns do not try to start a second Flask server or Streamlit CLI.
    _LAUNCH_FLAG = 'CAR_DEPRECIATION_APP_LAUNCHED'
    if os.environ.get(_LAUNCH_FLAG) != '1':
        os.environ[_LAUNCH_FLAG] = '1'
        try:
            # Start Flask in a background thread. daemon=True lets the
            # process exit once Streamlit stops, instead of hanging on the
            # server thread.
            threading.Thread(
                target=lambda: app.run(debug=False, use_reloader=False),
                daemon=True,
            ).start()

            # Run Streamlit
            import streamlit.web.cli as stcli

            sys.argv = ["streamlit", "run", __file__]
            sys.exit(stcli.main())
        except Exception as exc:
            # SystemExit/KeyboardInterrupt are not Exception subclasses, so a
            # normal shutdown is no longer reported as an error.
            print(f"An exception occurred: {exc}")
            sys.exit(1)