File size: 1,777 Bytes
8b19012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

def generate_prompt(instruction, input=""):
    """Build the text prompt that is fed to the model.

    Both arguments are stripped, have CRLF line endings converted to LF and
    have every run of consecutive newlines collapsed to a single newline.
    The templates below use a blank line to separate their sections, so user
    text must never contain one itself.

    Args:
        instruction: The user's request. Must be a string.
        input: Optional extra context. An empty string or ``None`` selects
            the chat-style template. (The name shadows the builtin but is
            kept so existing keyword callers keep working.)

    Returns:
        The "Instruction / Input / Response" prompt when ``input`` is
        non-empty after normalization, otherwise the two-turn chat prompt
        ending in the open ``Lover:`` turn.
    """
    def _normalize(text):
        text = text.strip().replace('\r\n', '\n')
        # A single replace('\n\n', '\n') pass turns '\n\n\n' into '\n\n',
        # which still contains a blank line; repeat until none are left.
        while '\n\n' in text:
            text = text.replace('\n\n', '\n')
        return text

    instruction = _normalize(instruction)
    input = _normalize(input or "")
    if input:
        return f"""Instruction: {instruction}

Input: {input}

Response:"""
    else:
        return f"""User: hi

Lover: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.

User: {instruction}

Lover:"""

# Demo script: load a local causal LM, build a prompt with generate_prompt()
# and stream a sampled completion to stdout one token at a time.

model_path = "models/rwkv-6-world-1b6/" # Path to your local model directory

MAX_NEW_TOKENS = 333   # Upper bound on the number of sampled tokens
STOP_SEQUENCE = "\n\n"  # Blank line: the separator generate_prompt() puts between sections

model = AutoModelForCausalLM.from_pretrained(
    model_path, 
    trust_remote_code=True, 
    use_flash_attention_2=False  # Explicitly disable Flash Attention
).to(torch.float32)


# NOTE(review): eos_token "</ s>" contains a space while bos_token is "</s>";
# this looks like a typo but is left unchanged -- confirm against the
# tokenizer shipped with the model.
tokenizer = AutoTokenizer.from_pretrained(
    model_path, 
    bos_token="</s>",
    eos_token="</ s>",
    unk_token="<unk>",
    pad_token="<pad>",
    trust_remote_code=True, 
    padding_side='left', 
    clean_up_tokenization_spaces=False  # Or set to True if you prefer
)

print(tokenizer.special_tokens_map)

text = "Hi"

prompt = generate_prompt(text)

input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# End-of-sequence ids come from the model's generation config (an int, a list
# or None); the tokenizer's eos_token above is not relied on for stopping.
_eos = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
if _eos is None:
    stop_token_ids = set()
elif isinstance(_eos, int):
    stop_token_ids = {_eos}
else:
    stop_token_ids = set(_eos)

# Generate text token by token, stopping at end-of-sequence or STOP_SEQUENCE.
# NOTE: every generate() call re-processes the whole sequence so far; that
# per-token loop is kept from the original so output can be streamed.
generated_text = ""
emitted_upto = input_ids.shape[1]  # Index of the first token not yet printed
for _ in range(MAX_NEW_TOKENS):
    output = model.generate(input_ids, max_new_tokens=1, do_sample=True, temperature=1.0, top_p=0.3, top_k=0)

    # Do not feed an end-of-sequence token back in and keep sampling past it.
    if output[0, -1].item() in stop_token_ids:
        break

    input_ids = output  # Update input_ids for next iteration

    # Decode everything not yet printed rather than only the last token: a
    # multi-byte character may be split across tokens, and decoding a lone
    # fragment yields U+FFFD. Hold output back until the tail is complete.
    new_word = tokenizer.decode(input_ids[0][emitted_upto:], skip_special_tokens=True)
    if new_word.endswith("\ufffd"):
        continue
    emitted_upto = input_ids.shape[1]

    candidate = generated_text + new_word
    stop_at = candidate.find(STOP_SEQUENCE)
    if stop_at != -1:
        # Print only what precedes the stop sequence (possibly nothing) and
        # drop the separator and anything after it from the result.
        print(candidate[len(generated_text):stop_at], end="", flush=True)
        generated_text = candidate[:stop_at]
        break

    print(new_word, end="", flush=True)  # Print word-by-word
    generated_text = candidate

print()  # Add a newline at the end