import os, gc
from typing import AsyncGenerator
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from asyncio import sleep

class Answerer:
  """Streams text completions from an RWKV model.

  The model weights are loaded once in the constructor; each call of the
  instance runs one independent generation (the recurrent state starts
  from ``None`` every time).
  """

  def __init__(self, model: str, vocab: str, strategy: str, ctx_limit: int):
    """Load the model and build its tokenizer pipeline.

    Args:
      model: Base name of the weights file; loaded from ``models/{model}.pth``.
      vocab: Vocabulary/tokenizer identifier handed to ``PIPELINE``.
      strategy: Strategy string handed to ``RWKV``.
      ctx_limit: Maximum number of trailing prompt tokens fed to the model.
    """
    # NOTE(review): rwkv.model is imported at module top, i.e. before this
    # assignment runs — confirm the library still honours the flag when it
    # is set this late.
    os.environ["RWKV_JIT_ON"] = "1"
    # os.environ["RWKV_CUDA_ON"] = "1"

    self.__model = RWKV(f"models/{model}.pth", strategy=strategy)
    self.__pipeline = PIPELINE(self.__model, vocab)
    self.ctx_limit = ctx_limit

  async def __call__(
    self,
    input: str,
    max_output_length_tk: int,
    chaos = .1,
    repetitiveness = .3,
    diversity = 0,
    _count_penalty = 1,
  ) -> AsyncGenerator[str, None]:
    """Generate a completion for ``input``, yielding the text produced so far.

    Every yielded value is the *cumulative* output, not a delta. Generation
    ends when token ``0`` is sampled, when the output (ignoring trailing
    whitespace) ends with ``"\\n\\nUser:"`` (that suffix is stripped from the
    final yielded value), or after ``max_output_length_tk`` sampling steps.

    Args:
      input: Prompt text; surrounding whitespace is stripped before encoding.
      max_output_length_tk: Upper bound on the number of sampled tokens.
      chaos: Passed to ``sample_logits`` as ``temperature``.
      repetitiveness: Passed to ``sample_logits`` as ``top_p``.
      diversity: Flat penalty subtracted from the logit of every token that
        has already been generated.
      _count_penalty: Penalty multiplied by a token's (decayed) occurrence
        count and subtracted from its logit.

    Yields:
      The accumulated output text after each step that decodes cleanly.
    """
    stop_tokens = (0,)
    prompt = input.strip()

    result: str = ""
    # Decayed occurrence counts per generated token id (floats after decay).
    occurrences: dict[int, float] = {}
    # Tokens sampled but not yet emitted because they do not decode cleanly.
    pending: list[int] = []
    current_token = None
    state = None
    out = None

    try:
      for _ in range(max_output_length_tk):
        # First step consumes the (truncated) prompt; later steps feed back
        # only the token that was just sampled, relying on the carried state.
        if current_token is None:
          step_tokens = self.__pipeline.encode(prompt)[-self.ctx_limit:]
        else:
          step_tokens = [current_token]
        out, state = self.__model.forward(step_tokens, state)

        # Presence + frequency penalty on everything generated so far.
        for token, count in occurrences.items():
          out[token] -= diversity + count * _count_penalty

        current_token = self.__pipeline.sample_logits(
          out,
          temperature=chaos,
          top_p=repetitiveness,
        )
        if current_token in stop_tokens:
          break

        pending.append(current_token)

        # Older occurrences fade so the penalty favours recent repetition.
        for token in occurrences:
          occurrences[token] *= 0.996
        occurrences[current_token] = occurrences.get(current_token, 0) + 1

        text: str = self.__pipeline.decode(pending)
        if "\ufffd" in text:
          # Presumably an incomplete multi-byte sequence: keep the tokens
          # and retry the decode once the next token has been appended.
          continue

        pending.clear()
        result += text
        trimmed = result.rstrip()
        if trimmed.endswith("\n\nUser:"):
          # The model started writing the next user turn; cut it and stop.
          yield trimmed.removesuffix("\n\nUser:")
          break
        yield result
        # Hand control back to the event loop between tokens.
        await sleep(.02)
    finally:
      # Runs on normal completion, on early close by the consumer and on
      # errors. Rebinding (rather than `del`) is safe even when the loop
      # never assigned these names.
      out = state = None
      gc.collect()