import os
from langchain_aws import ChatBedrock
import boto3

from dotenv import load_dotenv
from pathlib import Path
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
from rl_environment import GeneralizedRLEnvironment

# Populate the process environment from a ".env" file. The path is relative,
# so it is resolved against the current working directory at import time.
env_path = Path(".") / ".env"
load_dotenv(dotenv_path=env_path)

# Bedrock model identifier used by get_llm().
model_id = "us.anthropic.claude-3-5-haiku-20241022-v1:0"
# AWS region for Bedrock calls; overridable via BEDROCK_MODEL_REGION.
region_name = os.getenv("BEDROCK_MODEL_REGION", "us-east-1")

def get_llm(*, temperature=0.1):
    """Create a ``ChatBedrock`` chat model for the configured Bedrock model.

    A ``bedrock-runtime`` boto3 client is created for the module-level
    ``region_name`` and wrapped in ``ChatBedrock`` together with the
    module-level ``model_id``.

    Args:
        temperature: Sampling temperature forwarded via ``model_kwargs``.
            Keyword-only; defaults to 0.1, the value previously hard-coded.

    Returns:
        A configured ``ChatBedrock`` instance.
    """
    # NOTE: retry/timeout tuning (previously sketched here as commented-out
    # code) would be supplied as a botocore ``Config`` via ``config=`` on
    # ``boto3.client``; no config is passed, so client defaults apply.
    bedrock_client = boto3.client('bedrock-runtime', region_name=region_name)
    llm = ChatBedrock(
        client=bedrock_client,
        model_id=model_id,
        region_name=region_name,
        model_kwargs={"temperature": temperature},
    )
    print("get_llm | region_name:", region_name)
    return llm

def process_input_json(observable_factors, constant_factor, model_selection, actions, boundaries):
    """Convert the raw input payload into the RL environment configuration dict.

    TODO: the number of input parameters for each prediction model is not
    captured yet; only target columns, model type and forecast horizon are.

    Args:
        observable_factors: Mapping with optional ``'dataAccess'`` and
            ``'workspace'`` lists. Each entry may hold a ``'children'`` list of
            columns, each with ``'label'`` and ``['metadata']['DataType']``.
            ``None`` is treated as empty.
        constant_factor: Iterable of dicts with ``'name'`` and ``'value'``.
        model_selection: Iterable of dicts with ``'modelType'``,
            ``'forecastHorizon'`` and optional ``'targetColumns'``.
        actions: Iterable of dicts with ``'actionName'``, ``'dataType'`` and
            optional ``'values'``.
        boundaries: Iterable of dicts with ``['targetColumn']['name']``,
            ``'lowerLimit'``, ``'upperLimit'`` and optional ``'startingValue'``.
        Any of the iterables may be ``None``, which is treated as empty.

    Returns:
        dict with ``state`` (observable_factors, constant_factors,
        model_predictions), ``actions`` (action_space, RL_boundaries), an empty
        ``rl_agent`` string and an ``environment`` section with empty
        ``step``/``reward`` lists. An observable factor gains a
        ``'starting_value'`` key when a boundary with the same name supplies one.
    """
    import logging

    processed_json = {
        "state": {
            "observable_factors": [],
            "constant_factors": {},
            "model_predictions": []
        },
        "actions": {
            "action_space": [],
            "RL_boundaries": {}
        },
        "rl_agent": "",
        "environment": {
            "step": [],
            "reward": []
        }
    }
    state = processed_json['state']
    action_section = processed_json['actions']

    # Columns from both sources are merged into one flat list, dataAccess
    # first. The source (database/workspace) label is deliberately not part
    # of the variable name.
    observable_factors = observable_factors or {}
    for source_key in ('dataAccess', 'workspace'):
        for data in observable_factors.get(source_key, []):
            for column in data.get('children', []):
                state['observable_factors'].append({
                    'name': column['label'],
                    'type': column['metadata']['DataType'],
                })

    # Spaces in constant names become underscores in the dict key.
    for constant in constant_factor or []:
        state['constant_factors'][constant['name'].replace(" ", "_")] = constant['value']

    # One prediction entry per target column of each selected model.
    for model in model_selection or []:
        for target in model.get('targetColumns', []):
            state['model_predictions'].append({
                "name": target,
                "model_type": model['modelType'],
                "number_of_values_to_derive": model['forecastHorizon'],
            })

    for action in actions or []:
        action_obj = {
            "name": action['actionName'].replace(" ", "_"),
            "type": action['dataType'],
        }
        if 'values' in action:
            action_obj['my_list'] = action['values']
        action_section['action_space'].append(action_obj)

    starting_values = {}
    for boundary in boundaries or []:
        name = boundary['targetColumn']['name']
        action_section['RL_boundaries'][name] = [boundary['lowerLimit'], boundary['upperLimit']]

        # 0 / 0.0 is a legitimate starting value, so a truthiness test would
        # wrongly drop it; only a missing key, None or "" means "not set".
        starting_value = boundary.get('startingValue')
        if starting_value is not None and starting_value != "":
            starting_values[name] = starting_value

    logging.getLogger(__name__).debug("Starting values: %s", starting_values)
    for obs in state['observable_factors']:
        if obs['name'] in starting_values:
            obs['starting_value'] = starting_values[obs['name']]

    return processed_json

def evalute_final_json(json: dict):
    """Evaluate a PPO agent on an environment built from a processed config.

    A ``GeneralizedRLEnvironment`` is constructed from *json*, a PPO agent
    with an ``"MlpPolicy"`` (``verbose=1``) is attached to it, and the agent
    is scored over a single evaluation episode.

    NOTE(review): the agent is never trained (no ``learn()`` call), so the
    score reflects a freshly initialised policy. Presumably this is a smoke
    test that the config yields a runnable environment; confirm.

    Args:
        json: Config dict, presumably as produced by ``process_input_json``.

    Returns:
        ``(mean_reward, std_reward)`` from ``evaluate_policy``.
    """
    environment = GeneralizedRLEnvironment(json)
    agent = PPO("MlpPolicy", environment, verbose=1)
    avg_reward, reward_spread = evaluate_policy(agent, environment, n_eval_episodes=1)
    return avg_reward, reward_spread

def get_variables_and_boundaries(json: dict):
    """Render the processed RL config as a Markdown table of its variables.

    Rows are emitted in this order: observable factors, constant factors,
    model predictions, actions. Lower/upper limits come from
    ``json['actions']['RL_boundaries']``. A variable with no boundary entry
    gets "N/A" limits instead of raising ``KeyError``. Constant factors always
    show "N/A" limits and type "double". Model predictions of type
    'Train-Time-Series' are shown as a 'list' with their number of values;
    all others are shown as 'double'.

    Args:
        json: Config dict shaped like the output of ``process_input_json``.

    Returns:
        The table as a string. It begins with a newline, then the header and
        separator rows, then one newline-separated row per variable.
    """
    details = """
| Variable Name | Type | Number of Values in List | Values of list | Lower Limit | Upper Limit |
| --- | --- | --- | --- | --- | --- |"""

    row_details = "| {variable_name} | {type} | {number_of_values} | {values_of_list} | {lower_limit} | {upper_limit} |"
    boundaries = json['actions'].get('RL_boundaries', {})

    def _row(name, var_type, number_of_values="N/A", values_of_list="N/A", bounded=True):
        """Format one table row, looking up limits for *name* when *bounded*."""
        limits = boundaries.get(name) if bounded else None
        lower, upper = (limits[0], limits[1]) if limits is not None else ("N/A", "N/A")
        return row_details.format(variable_name=name, type=var_type,
                                  number_of_values=number_of_values,
                                  values_of_list=values_of_list,
                                  lower_limit=lower, upper_limit=upper)

    state = json['state']
    rows = [_row(obs['name'], obs['type']) for obs in state['observable_factors']]
    rows += [_row(constant, "double", bounded=False) for constant in state['constant_factors']]

    for model in state['model_predictions']:
        is_series = model['model_type'] == 'Train-Time-Series'
        rows.append(_row(model['name'], 'list' if is_series else 'double',
                         number_of_values=model['number_of_values_to_derive'] if is_series else 'N/A'))

    for action in json['actions']['action_space']:
        rows.append(_row(action['name'], action['type'],
                         values_of_list=action['my_list'] if 'my_list' in action else "N/A"))

    # Join once instead of repeated string concatenation.
    return "\n".join([details, *rows])