chatntq-7b-jpntuned Card

Model Details

chatntq-7b-jpntuned is a chat assistant trained by fine-tuning BlinkDL/rwkv-4-world on user-shared conversations collected from ShareGPT.

Uses
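The example below expects the checkpoint and vocabulary file to be available locally. A minimal sketch for fetching them with huggingface_hub (the repo id and the vocabulary filename are assumptions inferred from the paths used in this card; adjust them if the files are hosted differently):

from huggingface_hub import hf_hub_download

# Assumed repo id and filenames -- verify against the actual repository contents.
model_file = hf_hub_download(
    repo_id="NTQAI/chatntq-7b-jpntuned",
    filename="ChatNTQ-7B-RWKV-4-World-JPNtuned-ctx2048.pth",
)
vocab_file = hf_hub_download(
    repo_id="NTQAI/chatntq-7b-jpntuned",
    filename="rwkv_vocab_v20230424.txt",
)

The full inference example: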

import os, gc, copy, torch
import gradio as gr
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1'
from rwkv.model import RWKV
model_path = "chatntq-7b-jpntuned/ChatNTQ-7B-RWKV-4-World-JPNtuned-ctx2048.pth"
WORD_NAME = "rwkv_vocab_v20230424"  # copy rwkv_vocab_v20230424.txt from chatntq-7b-jpntuned into the same folder as this script
ctx_limit = 1024
model = RWKV(model=model_path, strategy='cuda fp16i8 *24 -> cuda fp16')  # first 24 layers int8-quantized on GPU, remaining layers in fp16
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, WORD_NAME)
def generate_prompt(instruction):
    return f"\x00Human: {instruction}\x00Assistant:  "

def evaluate(
    prompt,
    token_count=1024,
    temperature=1.2,
    top_p=0.5,
    presencePenalty = 0.4,
    countPenalty = 0.4,
):
    args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)),
                         top_p=float(top_p),
                         alpha_frequency=countPenalty,
                         alpha_presence=presencePenalty,
                         token_ban=[],       # ban the generation of some tokens
                         token_stop=[0, 1])  # stop generation whenever you see any token here

    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    state = None
    prompt = generate_prompt(prompt)
    print(prompt)
    for i in range(int(token_count)):
        # feed the full (truncated) prompt on the first step, then one token at a time
        out, state = model.forward(pipeline.encode(prompt)[-ctx_limit:] if i == 0 else [token], state)
        # apply presence and frequency penalties to tokens that have already been generated
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens += [token]
        occurrence[token] = occurrence.get(token, 0) + 1

        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:  # only flush once the buffer decodes to complete UTF-8 characters
            out_str += tmp
            out_last = i + 1
    gc.collect()
    torch.cuda.empty_cache()
    return out_str
if __name__ == "__main__":
    question = "東京の人口はどれくらいですか?"  # "Roughly what is the population of Tokyo?"
    response = evaluate(question)
    print(response)
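
Because gradio is already imported, evaluate can also be served through a simple web UI. A minimal sketch (not part of the original example; the labels, title, and launch options are arbitrary):

import gradio as gr

demo = gr.Interface(
    fn=evaluate,  # called with only the prompt; the other arguments keep their defaults
    inputs=gr.Textbox(label="Prompt", placeholder="東京の人口はどれくらいですか?"),
    outputs=gr.Textbox(label="Response"),
    title="chatntq-7b-jpntuned",
)
demo.launch()  # pass share=True for a temporary public link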

Contact information

For personal communication related to this project, please contact Nha Nguyen Van ([email protected]).
