import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    pipeline,
)

#import subprocess
# Install flash attention, skipping CUDA build if necessary
#subprocess.run(
#    "pip install flash-attn --no-build-isolation",
#    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
#    shell=True,
#)

# Generation limits used by the UI sliders and by `generate`.
MAX_MAX_NEW_TOKENS = 1024  # upper bound of the "Max new tokens" slider
DEFAULT_MAX_NEW_TOKENS = 512  # initial value of that slider
# Prompt length cap (in tokens); overridable via the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown heading rendered above the chat interface.
DESCRIPTION = "# Try Patched Chat\n"

# Markdown footer rendered below the chat interface.
LICENSE = (
    "---\n"
    "This space was created by [patched](https://patched.codes).\n"
)

# Load the model only when a GPU is present; otherwise just warn in the header.
if torch.cuda.is_available():
    # model_id = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    # 4-bit quantization is requested through a BitsAndBytesConfig; passing
    # `load_in_4bit=True` straight to from_pretrained is deprecated.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        # NOTE(review): lets the model repo execute its own Python code. Kept so
        # the commented-out DeepSeek model can be swapped back in; consider
        # dropping it while only the Llama 3 checkpoint is used.
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.padding_side = "right"
else:
    # No model/tokenizer is created on CPU, so `generate` cannot run here.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
    
@spaces.GPU(duration=60)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.2,
    top_p: float = 0.95,
) -> Iterator[str]:
    """Stream the model's reply to ``message``, continuing ``chat_history``.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user, assistant)`` turns supplied by the chat UI.
        system_prompt: Optional system prompt; omitted when empty.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cumulative probability threshold.

    Yields:
        The reply text accumulated so far, growing with each streamed chunk.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        # Defensive: never hand a ``None`` reply to the chat template.
        if assistant is not None:
            conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant turn header so the model
    # starts its answer right away. Without it the model has to emit the header
    # itself, and the non-special role text can leak into the streamed reply
    # (skip_special_tokens only strips the special marker tokens).
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the oldest context is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Llama 3 Instruct closes a turn with <|eot_id|>, which can differ from the
    # tokenizer's eos token, so stop on either. Ids that do not resolve are dropped.
    terminators = [
        token_id
        for token_id in (
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        )
        if token_id is not None
    ]

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        eos_token_id=terminators,
    )
    # model.generate blocks until done, so run it in a worker thread and
    # consume decoded text from the streamer as it is produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation has finished; reap the worker.
    thread.join()

# Example prompt: ask the model to fix a CWE-327 (broken/risky crypto, here MD5)
# issue in a small file-hashing snippet.
example1='''Fix vulnerability CWE-327: Use of a Broken or Risky Cryptographic Algorithm in the following code snippet.

def md5_hash(path):
     with open(path, "rb") as f:
         content = f.read()
     return hashlib.md5(content).hexdigest()
'''

# Example prompt: summarize the changes between two versions of a C linked-list
# routine. This is a raw string so the C "\n" escapes in the printf calls reach the
# model as written instead of being turned into real line breaks mid-literal.
example2 = r'''Carefully analyze the given old code and new code and generate a summary of the changes.

Old Code:

#include <stdio.h>
#include <stdlib.h>

typedef struct Node {
    int data;
    struct Node *next;
} Node;

void processList() {
    Node *head = (Node*)malloc(sizeof(Node));
    head->data = 1;
    head->next = (Node*)malloc(sizeof(Node));
    head->next->data = 2;

    printf("First element: %d\n", head->data);

    free(head->next); 
    free(head); 

    printf("Accessing freed list: %d\n", head->next->data);
}

New Code:

#include <stdio.h>
#include <stdlib.h>

typedef struct Node {
    int data;
    struct Node *next;
} Node;

void processList() {
    Node *head = (Node*)malloc(sizeof(Node));
    if (head == NULL) {
        perror("Failed to allocate memory for head");
        return;
    }

    head->data = 1;
    head->next = (Node*)malloc(sizeof(Node));
    if (head->next == NULL) {
        free(head);
        perror("Failed to allocate memory for next node");
        return;
    }
    head->next->data = 2;

    printf("First element: %d\n", head->data);

    free(head->next); 
    head->next = NULL; 
    free(head); 
    head = NULL; 

    if (head != NULL && head->next != NULL) {
        printf("Accessing freed list: %d\n", head->next->data);
    }
}
'''

# Example prompt: ask for a YES/NO verdict on CWE-117 (improper output
# neutralization for logs) for a Flask handler that logs a request parameter.
example3='''Is the following code prone to CWE-117: Improper Output Neutralization for Logs. Respond only with YES or NO.

from flask import Flask, request, jsonify
import logging

app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@app.route('/api/data', methods=['GET'])
def get_data():
    api_key = request.args.get('api_key')
    logger.info("Received request with API Key: %s", api_key)  
    data = {"message": "Data processed"}
    return jsonify(data)
'''

# Example prompt: fix an OS command injection (CWE-78) in a subprocess wrapper.
# This is a raw string with single braces so the snippet is valid Python exactly
# as the model sees it. Doubled braces are only an escape inside f-strings or
# str.format templates. In a non-raw literal, "\n" would become a real line
# break that splits the string inside the snippet.
example4 = r'''Fix vulnerability CWE-78: Improper Neutralization of Special Elements used in an OS Command ('OS Command Injection') in the following code snippet.

def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
    if desc is not None:
        print(desc)
    run_kwargs = {
        "args": command,
        "shell": True,
        "env": os.environ if custom_env is None else custom_env,
        "encoding": 'utf8',
        "errors": 'ignore',
    }
    if not live:
        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
    result = subprocess.run(**run_kwargs)  ##here
    if result.returncode != 0:
        error_bits = [
            f"{errdesc or 'Error running command'}.",
            f"Command: {command}",
            f"Error code: {result.returncode}",
        ]
        if result.stdout:
            error_bits.append(f"stdout: {result.stdout}")
        if result.stderr:
            error_bits.append(f"stderr: {result.stderr}")
        raise RuntimeError("\n".join(error_bits))
    return (result.stdout or "")
'''

# Chat UI wired to `generate`. The additional inputs are passed after the message
# and history, in order: system_prompt, max_new_tokens, temperature, top_p.
chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(height="480px"),
    additional_inputs=[
        gr.Textbox(lines=4, label="System prompt"),
        gr.Slider(
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            value=DEFAULT_MAX_NEW_TOKENS,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.2, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.05,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    # Each example supplies only the message; the extra inputs keep their defaults.
    examples=[
        [prompt]
        for prompt in (
            "You are a helpful coding assistant. Create a snake game in Python.",
            example1,
            example2,
            example3,
            example4,
        )
    ],
    # stop_btn=None: no stop button is offered while a reply streams.
    stop_btn=None,
)

# Page layout: markdown header, the chat interface, then the markdown footer.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE, elem_classes="contain")

if __name__ == "__main__":
    # Allow at most 20 pending requests in the queue, then start the server.
    demo.queue(max_size=20)
    demo.launch()