import argparse
import numpy as np

from rich import print

from bark_infinity import config

logger = config.logger

from bark_infinity import generation
from bark_infinity import api

from bark_infinity import text_processing
import time

import random

# Built-in fallback prompts. main() draws its test prompts from this list when
# neither --text_prompt nor --prompt_file is supplied.
text_prompts_in_this_file = []


import torch
from torch.utils import collect_env


# Self-test prompt that speaks the current date/time. This is best effort: if
# building or appending it fails for any reason, the error is printed and the
# module simply continues without this prompt.
try:
    text_prompts_in_this_file.append(
        f"It's {text_processing.current_date_time_in_words()} And if you're hearing this, Bark is working. But you didn't provide any text"
    )
except Exception as exc:
    print(f"An error occurred: {exc}")

# Static fallback prompts. The loop target is deliberately named `text_prompt`
# so that the module-level name stays bound to the last prompt afterwards.
for text_prompt in (
    """
    In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move. However, Bark is working.
""",
    """
    A common mistake that people make when trying to design something completely foolproof is to underestimate the ingenuity of complete fools.
""",
):
    text_prompts_in_this_file.append(text_prompt)


def get_group_args(group_name, updated_args):
    """Return the arguments in ``updated_args`` that belong to one config group.

    Args:
        group_name: Key into ``config.DEFAULTS`` selecting the argument group.
        updated_args: An ``argparse.Namespace`` (or any object ``vars()``
            accepts) holding the current argument values.

    Returns:
        A dict mapping argument name to value for every attribute of
        ``updated_args`` whose name is a key of
        ``dict(config.DEFAULTS[group_name])``. The namespace's attribute
        order is preserved.

    Raises:
        KeyError: If ``group_name`` is not present in ``config.DEFAULTS``
            (assuming it is a plain mapping).
    """
    # Build the group's key lookup once instead of once per attribute.
    group_keys = dict(config.DEFAULTS[group_name])
    return {
        name: value
        for name, value in vars(updated_args).items()
        if name in group_keys
    }


def _apply_generation_overrides(args):
    """Copy the global model switches from ``args`` onto the ``generation`` module."""
    if args.OFFLOAD_CPU is not None:
        generation.OFFLOAD_CPU = args.OFFLOAD_CPU
    elif generation.get_SUNO_USE_DIRECTML() is not True:
        generation.OFFLOAD_CPU = True  # default on just in case
    if args.USE_SMALL_MODELS is not None:
        generation.USE_SMALL_MODELS = args.USE_SMALL_MODELS
    if args.GLOBAL_ENABLE_MPS is not None:
        generation.GLOBAL_ENABLE_MPS = args.GLOBAL_ENABLE_MPS


def _print_requested_reports(args):
    """Print the diagnostic reports selected on the command line, unless ``--silent``."""
    if args.silent:
        return

    if args.detailed_gpu_report or args.show_all_reports:
        print(api.startup_status_report(quick=False))
    elif not args.text_prompt and not args.prompt_file:
        # No prompt given: probably a test run, so show the quick report.
        print(api.startup_status_report(quick=True))
    if args.detailed_hugging_face_cache_report or args.show_all_reports:
        print(api.hugging_face_cache_report())
    if args.detailed_cuda_report or args.show_all_reports:
        print(api.cuda_status_report())
    if args.detailed_numpy_report:
        print(api.numpy_report())
    if args.run_numpy_benchmark or args.show_all_reports:
        # Imported lazily so the debug module is only loaded when requested.
        from bark_infinity.debug import numpy_benchmark

        numpy_benchmark()


def _collect_text_prompts(args):
    """Return the list of prompts to render, or ``None`` if the prompt file failed to load."""
    if args.text_prompt:
        return [args.text_prompt]

    if args.prompt_file:
        file_text = text_processing.load_text(args.prompt_file)
        if file_text is None:
            logger.error(f"Error loading file: {args.prompt_file}")
            return None
        # Reuse the text that was already loaded rather than reading the file again.
        prompts = text_processing.split_text(
            file_text,
            args.split_input_into_separate_prompts_by,
            args.split_input_into_separate_prompts_by_value,
        )

        print(f"\nProcessing file: {args.prompt_file}")
        print(f"  Looks like: {len(prompts)} prompt(s)")
        return prompts

    print("No --text_prompt or --prompt_file specified, using test prompt.")
    # Never ask for more samples than exist; random.sample raises ValueError otherwise.
    sample_size = min(2, len(text_prompts_in_this_file))
    return random.sample(text_prompts_in_this_file, sample_size)


def main(args):
    """Run the command-line workflow described by ``args``.

    Steps, in order: apply the log level and global generation switches,
    print any requested reports, handle the ``--list_speakers`` and
    ``--render_npz_samples`` shortcuts (each returns immediately), resolve
    the prompts to render, preload the models, then call
    ``api.generate_audio_long`` once per prompt per output iteration.

    Args:
        args: ``argparse.Namespace`` with the parsed and default-filled
            options. It is mutated: ``current_iteration`` and ``text_prompt``
            are overwritten before every generation call, and the whole
            namespace is forwarded as keyword arguments.

    Returns:
        None. Returns early if a shortcut option was handled or the prompt
        file could not be loaded.
    """
    if args.loglevel is not None:
        logger.setLevel(args.loglevel)

    _apply_generation_overrides(args)
    _print_requested_reports(args)

    if args.list_speakers:
        api.list_speakers()
        return

    if args.render_npz_samples:
        api.render_npz_samples()
        return

    text_prompts_to_process = _collect_text_prompts(args)
    if text_prompts_to_process is None:
        return

    prompt_count = len(text_prompts_to_process)

    # Every prompt is rendered output_iterations times, so the workload is the
    # product of the two counts, not their sum.
    things = prompt_count * args.output_iterations
    if things > 10 and not args.dry_run:
        print(
            f"WARNING: You are about to process {things} prompts. Consider using '--dry-run' to test things first."
        )

    print("Loading Bark models...")
    if not args.dry_run and generation.get_SUNO_USE_DIRECTML() is not True:
        generation.preload_models(
            args.text_use_gpu,
            args.text_use_small,
            args.coarse_use_gpu,
            args.coarse_use_small,
            args.fine_use_gpu,
            args.fine_use_small,
            args.codec_use_gpu,
            args.force_reload,
        )

    print("Done.")

    for idx, text_prompt in enumerate(text_prompts_to_process, start=1):
        if prompt_count > 1:
            print(f"\nPrompt {idx}/{prompt_count}:")

        for iteration in range(1, args.output_iterations + 1):
            if args.output_iterations > 1:
                print(f"\nIteration {iteration} of {args.output_iterations}.")
                if iteration == 1:
                    # NOTE(review): the "ss" prefix looks like leftover debug
                    # output; kept as-is to avoid changing what users see.
                    print("ss", text_prompt)

            # The generation call reads the per-run state from the namespace itself.
            args.current_iteration = iteration
            args.text_prompt = text_prompt

            api.generate_audio_long(**vars(args))


if __name__ == "__main__":
    # Script entry point: parse the command line, let config fill in the
    # per-group defaults, then hand a fresh Namespace built from that result
    # to main().
    cli_parser = config.create_argument_parser()
    parsed_cli_args = cli_parser.parse_args()
    args_with_defaults = config.update_group_args_with_defaults(parsed_cli_args)
    main(argparse.Namespace(**args_with_defaults))