# %%
import torch
from transformers import (
    BertTokenizer,
    BertForMaskedLM,
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertModel,
)
import numpy as np
import random
from itertools import islice
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
import os

# Character-level Japanese BERT checkpoint. Both the tokenizer and the encoder
# are resolved by from_pretrained (Hugging Face Hub or local cache).
model_name = "tohoku-nlp/bert-base-japanese-char-v3"
tokenizer, base_model = (
    BertTokenizer.from_pretrained(model_name),
    BertModel.from_pretrained(model_name),
)


class punctuation_predictor(torch.nn.Module):
    """Per-token classification head on top of a BERT-style encoder.

    Every position of the encoder's ``last_hidden_state`` is passed through
    dropout and a linear layer that produces ``num_labels`` logits. In this
    script, column 0 is read as the comma score and column 1 as the period
    score.
    """

    def __init__(self, base_model, hidden_size=None, num_labels=2, dropout=0.2):
        """Build the head.

        Args:
            base_model: Encoder called as
                ``base_model(input_ids=..., attention_mask=...)`` whose result
                exposes ``last_hidden_state``.
            hidden_size: Width of ``last_hidden_state``. When omitted, it is
                taken from ``base_model.config.hidden_size`` if present,
                otherwise 768 (the value this class used to hard-code).
            num_labels: Number of logits produced per token (default 2).
            dropout: Dropout probability applied before the linear layer.
        """
        super().__init__()
        if hidden_size is None:
            # Derive the width from the encoder so other BERT sizes work too;
            # fall back to 768 for encoders without a config object.
            config = getattr(base_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(dropout)
        self.linear = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return logits for every token position.

        The linear layer is applied over the last (hidden) dimension of the
        encoder's ``last_hidden_state``, so the output keeps the leading
        dimensions and ends in ``num_labels``.
        """
        last_hidden_state = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Linear broadcasts over all leading dims: all positions at once.
        return self.linear(self.dropout(last_hidden_state))


model = punctuation_predictor(base_model)
# map_location="cpu": inference below converts outputs with .numpy(), which
# needs CPU tensors. Without it, a checkpoint saved from a GPU run cannot be
# loaded on a machine without CUDA.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (consider weights_only=True if the installed torch supports it).
# NOTE(review): the path is resolved relative to the current working directory.
model.load_state_dict(
    torch.load("weight/punctuation_position_model.pth", map_location="cpu")
)
model.eval()  # inference mode: disables the head's dropout


def insert_punctuation(input, comma_pos, period_pos):
    """Rebuild text from token ids, appending predicted punctuation.

    Args:
        input: 1-D sequence of token-id tensors (indexed and ``.item()``-ed).
        comma_pos: Per-position truthy flags; append "、" when set.
        period_pos: Per-position truthy flags; append "。" when set. This
            takes precedence over a comma at the same position.

    Returns:
        The concatenated token strings with punctuation inserted. Ids <= 5
        are skipped, and iteration stops at the final position of ``input``.
    """
    last_index = len(input) - 1
    pieces = []
    for idx, (add_comma, add_period) in enumerate(zip(comma_pos, period_pos)):
        tok_id = input[idx].item()
        # NOTE(review): ids <= 5 are presumably the special tokens
        # ([PAD]/[UNK]/[CLS]/[SEP]/[MASK]...). Confirm against the vocab.
        # Characters that were tokenized to [UNK] are dropped from the output.
        if tok_id <= 5:
            continue
        if idx >= last_index:
            break
        piece = tokenizer.ids_to_tokens[tok_id]
        if add_period:
            piece += "。"
        elif add_comma:
            piece += "、"
        pieces.append(piece)
    return "".join(pieces)


def process_long_text(text, max_length=256, comma_thresh=0.1, period_thresh=0.1):
    """Re-punctuate Japanese text with the module-level punctuation model.

    Existing "、" and "。" are removed first. The text is then processed in
    chunks of at most ``max_length`` characters. Each character is passed
    to the tokenizer as its own whitespace-separated token. The model's
    sigmoid scores decide whether "。" (checked first) or "、" follows each
    character.

    Args:
        text: Input text.
        max_length: Chunk size in characters. Values above 510 are clamped.
            Each chunk is encoded into at most 512 tokens including [CLS]
            and [SEP], so a larger chunk would otherwise be truncated and
            lose characters.
        comma_thresh: Score above which "、" is inserted.
        period_thresh: Score above which "。" is inserted.

    Returns:
        The re-punctuated text.

    Raises:
        ValueError: If ``max_length`` is less than 1.
    """
    token_limit = 512
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")
    # Reserve two positions for [CLS] and [SEP]. With a larger chunk the
    # tokenizer's truncation would silently drop trailing characters.
    chunk_size = min(max_length, token_limit - 2)

    text = text.replace("、", "").replace("。", "")
    pieces = []
    # Inference only: don't build an autograd graph for every chunk.
    with torch.no_grad():
        for start in range(0, len(text), chunk_size):
            chunk = text[start : start + chunk_size]
            inputs = tokenizer(
                " ".join(chunk),
                max_length=token_limit,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            logits = model(inputs.input_ids, inputs.attention_mask)
            # One row per token of the first (only) batch item.
            # Column 0 = comma score, column 1 = period score.
            probs = torch.sigmoid(logits)[0].numpy()
            comma_pos = probs[:, 0] > comma_thresh
            period_pos = probs[:, 1] > period_thresh
            pieces.append(
                insert_punctuation(inputs.input_ids[0], comma_pos, period_pos)
            )
    return "".join(pieces)


# %%
if __name__ == "__main__":
    # Demo: re-punctuate an unpunctuated passage and print the result.
    sample_text = "女は昨夕艶めかしい姿をして彼の浴室の戸を開けた人に違なかった風呂場で彼を驚ろかした大きな髷をいつの間にか崩して尋常の束髪に結い更えたので彼はつい同じ人と気がつかずにいた彼はさらに声を聴いただけで顔を知らなかった伴の男の方をよそながらの初対面といった風に女と眺め比べた"
    punctuated = process_long_text(sample_text)
    print(punctuated)