import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
from accelerate import load_checkpoint_and_dispatch
import fcntl  # For file locking
import os  # For file operations
import time  # For sleep function

# Configure the PyTorch CUDA caching allocator before any CUDA work happens in
# this process: the allocator reads PYTORCH_CUDA_ALLOC_CONF when it initialises,
# so it must be set at import time. max_split_size_mb caps the size of cached
# blocks the allocator is willing to split.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128")

# Echo the effective value so it shows up in the container logs.
print("PYTORCH_CUDA_ALLOC_CONF: {}".format(os.environ.get("PYTORCH_CUDA_ALLOC_CONF")))

# Per-process cache filled by model_fn so later invocations reuse the loaded objects.
model = tokenizer = None

# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Render chat ``messages`` into one prompt string using the tokenizer's chat template.

    The template is applied as text (``tokenize=False``) with the generation
    prompt appended (``add_generation_prompt=True``), so the string is ready to
    be tokenized and handed to ``generate``.
    """
    render = tokenizer.apply_chat_template
    return render(messages, add_generation_prompt=True, tokenize=False)

def model_fn(model_dir, context=None):
    """
    Load the model and tokenizer once per worker process (SageMaker ``model_fn`` hook).

    Loads are serialized across worker processes on this host by an exclusive
    ``fcntl.flock`` on ``/tmp/model_load.lock``, so only one worker reads the
    weights at a time. Every process still loads its own copy: module globals
    are per-process, so a model loaded by another worker can never be reused here.

    Args:
        model_dir: Directory holding the Hugging Face model and tokenizer files.
        context: Unused; accepted for SageMaker handler compatibility.

    Returns:
        A ``(model, tokenizer)`` tuple, cached in module globals after the first
        successful call.

    Raises:
        Re-raises (after printing) any exception from loading the model or tokenizer.
    """
    global model, tokenizer

    # Lock file that serializes model loading between worker processes.
    lock_file = "/tmp/model_load.lock"
    # Marker written while this process is loading (diagnostic only; never waited on).
    in_progress_file = "/tmp/model_loading_in_progress"

    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer

    with open(lock_file, 'w') as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)  # Exclusive lock

        try:
            if os.path.exists(in_progress_file):
                # Every loader removes the marker before releasing the lock, so a
                # marker seen while we hold the lock was left by a worker that died
                # mid-load (e.g. OOM-killed). Waiting on it would block forever.
                print("Found stale in-progress marker from an interrupted load; ignoring it.")

            if model is None or tokenizer is None:
                print("No one is loading, proceeding to load the model.")
                with open(in_progress_file, 'w') as marker:
                    marker.write("loading")

                print("Loading the model and tokenizer...")
                # Assign the globals only once both loads succeed, so a failure never
                # leaves a half-initialised cache behind.
                model, tokenizer = _load_model_and_tokenizer(model_dir)

        except Exception as e:
            print(f"Error loading model and tokenizer: {e}")
            raise

        finally:
            # Remove the marker (ours, or a stale one) before handing the lock on.
            if os.path.exists(in_progress_file):
                os.remove(in_progress_file)

            fcntl.flock(lock, fcntl.LOCK_UN)

    return model, tokenizer


def _load_model_and_tokenizer(model_dir):
    """Load the causal LM spread over the visible GPUs (disk offload for overflow) plus its tokenizer."""
    offload_dir = "/tmp/offload_dir"
    os.makedirs(offload_dir, exist_ok=True)

    # Passing device_map to from_pretrained builds the model without allocating
    # full weights on the CPU and loads each weight straight to its assigned device.
    # This avoids materialising the whole model in host RAM and then re-reading the
    # checkpoint with load_checkpoint_and_dispatch. It also uses the model's own
    # `_no_split_modules` (its layer block classes) when splitting. Naming the
    # top-level model class there would only prevent splitting.
    # Placement is layer-wise across devices, not tensor parallelism.
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype="auto",
        device_map="auto",
        max_memory={i: "15GiB" for i in range(torch.cuda.device_count())},  # per-GPU cap
        offload_folder=offload_dir,  # Offload layers that do not fit to disk
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return loaded_model, loaded_tokenizer

# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate a chat completion for one JSON request (SageMaker ``predict_fn`` hook).

    Args:
        input_data: JSON text with a ``messages`` list for the chat template and
            optional overrides: ``max_new_tokens`` (default 512), ``temperature``
            (0.7), ``top_p`` (0.9), ``repetition_penalty`` (1.0) and
            ``length_penalty`` (1.0).
        model_and_tokenizer: ``(model, tokenizer)`` tuple as returned by ``model_fn``.
        context: Unused; accepted for SageMaker handler compatibility.

    Returns:
        An OpenAI-style ``chat.completion`` dict whose message content is only the
        newly generated text, or ``{"error": ..., "details": ...}`` on any failure.
    """
    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        # Tokenize the input and send it to GPU 0 for generation
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")
        prompt_ids = inputs["input_ids"]
        prompt_len = prompt_ids.shape[-1]
        max_new_tokens = data.get("max_new_tokens", 512)

        outputs = model.generate(
            prompt_ids,
            # Pass the mask explicitly rather than letting generate() infer it.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=max_new_tokens,
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            do_sample=True,
        )

        # For decoder-only models generate() returns prompt + continuation; keep only
        # the new tokens so the reply does not echo the prompt and usage is accurate.
        completion_ids = outputs[0][prompt_len:]
        completion_tokens = len(completion_ids)
        generated_text = tokenizer.decode(completion_ids, skip_special_tokens=True)

        # Heuristic: exhausting the token budget is reported as "length". A reply whose
        # end-of-sequence token lands exactly on the budget is also labelled "length".
        hit_limit = isinstance(max_new_tokens, int) and completion_tokens >= max_new_tokens
        finish_reason = "length" if hit_limit else "stop"

        # Build response
        response = {
            "id": "chatcmpl-uuid",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": finish_reason
            }],
            "usage": {
                "prompt_tokens": prompt_len,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_len + completion_tokens
            }
        }
        return response

    except Exception as e:
        # Errors are returned as a payload rather than raised.
        return {"error": str(e), "details": repr(e)}

# Define input format for SageMaker
def input_fn(serialized_input_data, content_type, context=None):
    """
    Hand the raw request body through untouched (SageMaker ``input_fn`` hook).

    ``content_type`` and ``context`` are ignored; the body is returned as the
    exact object received, and decoding is left to the prediction step.
    """
    raw_body = serialized_input_data
    return raw_body

# Define output format for SageMaker
def output_fn(prediction_output, accept, context=None):
    """
    Serialize the prediction to a JSON string (SageMaker ``output_fn`` hook).

    ``accept`` and ``context`` are ignored; the output is always JSON text
    produced with the standard library encoder's default settings.
    """
    encoder = json.JSONEncoder()
    return encoder.encode(prediction_output)