#!/usr/bin/python3
# -*- coding: utf-8 -*-
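#
# Gradio demo that streams modern-Chinese-poetry generation from a fine-tuned Qwen-7B.
#
# Example invocation (the script name is illustrative; the checkpoint defaults to the
# path set in get_args below):
#   python3 demo.py --max_new_tokens 512 --temperature 0.35 --top_p 0.9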
import argparse
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
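    # the dataset and output-file arguments below are not used by this demo script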
    parser.add_argument("--train_subset", default="train.jsonl", type=str)
    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
    parser.add_argument(
        "--pretrained_model_name_or_path",
        default=(project_path / "trained_models/qwen_7b_chinese_modern_poetry").as_posix(),
        type=str
    )
    parser.add_argument("--output_file", default="result.xlsx", type=str)

    parser.add_argument("--max_new_tokens", default=512, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--temperature", default=0.35, type=float)
    parser.add_argument("--repetition_penalty", default=1.0, type=float)
    parser.add_argument('--device', default="cuda" if torch.cuda.is_available() else "cpu", type=str)

    args = parser.parse_args()
    return args


description = """
## Qwen-7B

Based on the [Qwen-7B](https://huggingface.co/qgyd2021/Qwen-7B) model, fine-tuned for 2 epochs on the [chinese_modern_poetry](https://huggingface.co/datasets/Iess/chinese_modern_poetry) dataset.

It generates modern Chinese poetry from a list of images, e.g.:
使用下列意象写一首现代诗:智慧,刀刃
(Write a modern poem using the following images: wisdom, blade.)
"""


# example prompts, kept in Chinese since that is the language the model was fine-tuned
# on; each asks for a modern poem using the listed images
examples = [
    "使用下列意象写一首现代诗:石头,森林",  # stone, forest
    "使用下列意象写一首现代诗:花,纱布",  # flower, gauze
    "使用下列意象写一首现代诗:山壁,彩虹,诗句,山坡,泪",  # cliff, rainbow, verse, hillside, tears
    "使用下列意象写一首现代诗:味道,黄金,名字,银子,女人",  # scent, gold, name, silver, woman
    "使用下列意象写一首现代诗:乳房,触感,车速,星星,路灯",  # breast, touch, speed, stars, streetlight
]


def main():
    args = get_args()

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, trust_remote_code=True)
    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are all None;
    # eod_id corresponds to the <|endoftext|> token, so use it for all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

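    # load the causal LM with automatic device placement; weights that do not fit
    # on the GPU are offloaded to ./offload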
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        # load_in_4bit=True,
    )
    # weights are already loaded in bfloat16 above; just switch to inference mode
    model = model.eval()

    def fn_non_stream(text: str):
        input_ids = tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids.to(args.device)
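        # wrap the prompt as <|endoftext|> prompt <|endoftext|>; this appears to be the
        # format used during fine-tuning (all three special ids alias eod_id above)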
        bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(args.device)
        eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(args.device)
        input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                do_sample=True,
                top_p=args.top_p,
                temperature=args.temperature,
                repetition_penalty=args.repetition_penalty,
                eos_token_id=tokenizer.eos_token_id
            )
            outputs = outputs.tolist()[0][len(input_ids[0]):]
            response = tokenizer.decode(outputs)
            response = response.strip().replace(tokenizer.eos_token, "").strip()

        return [(text, response)]

    def fn_stream(text: str):
        text = str(text).strip()

        input_ids = tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids.to(args.device)
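        # same <|endoftext|> wrapping of the prompt as in fn_non_stream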
        bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(args.device)
        eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(args.device)
        input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)

        # skip_prompt avoids re-emitting the input text; skip_special_tokens drops <|endoftext|>
        streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            top_p=args.top_p,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
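        # model.generate blocks until finished, so run it in a background thread
        # and consume the streamer iterator here as text chunks are decoded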
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output = ""
        for output_ in streamer:
            output_ = output_.replace(text, "")
            output_ = output_.replace(tokenizer.eos_token, "")

            output += output_

            result = [(text, output)]
            chatbot.value = result
            yield result

    with gr.Blocks() as blocks:
        gr.Markdown(value=description)

        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button("🗑️Clear", variant="secondary")

        gr.Examples(examples, text_box)

        text_box.submit(fn_stream, [text_box], [chatbot])
        submit_button.click(fn_stream, [text_box], [chatbot])
        clear_button.click(
            # an empty string clears the textbox; an empty list clears the chat history
            fn=lambda: ("", []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    blocks.queue().launch()

    return


if __name__ == '__main__':
    main()