# Adapted from the Gradio tutorials:
# https://www.gradio.app/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face

import gradio as gr

import torch

# Get cpu, gpu or mps device for training.
# See: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
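
# optional sanity check, as in the linked PyTorch quickstart:
# print(f"Using {device} device")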

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import StoppingCriteria
from transformers import StoppingCriteriaList
from transformers import TextIteratorStreamer

from threading import Thread

MODEL_ID = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
model = model.to(device)  # move the model to the selected device (CUDA, MPS, or CPU)
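# note (not in the original tutorial): float16 weights are only practical on GPU/MPS;
# if you run on CPU, you may want to load with torch_dtype=torch.float32 instead, e.g.
#     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32)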

class StopOnTokens(StoppingCriteria):
    """
    Class used as `stopping_criteria` in `generate_kwargs`; it provides an additional
    way of stopping the generation loop (if this class returns `True` on a token,
    generation is stopped).
    """
    # note: Python now supports type hints, see this: https://realpython.com/lessons/type-hinting/
    #       (for the **kwargs see also: https://realpython.com/python-kwargs-and-args/)
    # this could also be written: def __call__(self, input_ids, scores, **kwargs):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [29, 0]  # ids of the tokens that end a turn for this tokenizer (see the sketch after this class)
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
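
# Where do the ids in `stop_ids` come from? A small sanity-check sketch (hedged:
# the exact values depend on this model's tokenizer, so inspect them yourself):
#     print(tokenizer("<human>:").input_ids)  # the first id should be the '<' token
#     print(tokenizer.eos_token_id)           # end-of-text id for this tokenizer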

def predict(message, history):

    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # useful to debug
    # msg = "history"
    # print(msg)
    # print(*history_transformer_format, sep="\n")
    # print("***")    

    # at each step we feed the entire history as a single string, restoring the
    # format used in the model's fine-tuning data: each turn on a new line,
    # prefixed with <human>: or <bot>:
    messages = "".join(
        ["".join(
            ["\n<human>:"+item[0], "\n<bot>:"+item[1]]
         )
        for item in history_transformer_format]
    )
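    # for example, with an empty history and message == "Hello", 'messages' becomes:
    #   "\n<human>:Hello\n<bot>:"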
    # # to see what we feed to our net:
    # msg = "string prompt"
    # print(msg)
    # print("-" * len(msg))
    # print(messages)
    # print("-" * 40)
 
    # convert the string into tensors & move them to the selected device
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(
        tokenizer,
        # timeout=30., # no timeout until I implement error handling for the empty stream
        skip_prompt=True,
        skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id, # mute annoying warning: https://stackoverflow.com/a/71397707
        num_beams=1,  # this is for beam search (disabled), see: https://huggingface.co/blog/how-to-generate#beam-search
        stopping_criteria=StoppingCriteriaList([stop])
    )
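    # model.generate blocks until generation finishes, so it runs in a background
    # thread while the loop below consumes the streamer token by token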
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        # recall the "\n<human>:" / "\n<bot>:" format used when building 'messages':
        # the model signals the end of its reply by starting a new turn marker, so
        # we stream everything except the '<' that opens that marker
        # (StopOnTokens stops generation on the same token)
        if new_token != '<':
            partial_message += new_token
            yield partial_message


gr.ChatInterface(predict).queue().launch(debug=True, share=True)
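
# To launch, run this file with Python (the filename is just an example):
#   python chatbot_app.py
# share=True asks Gradio to create a temporary public URL in addition to the local one;
# debug=True blocks the main thread and prints errors to the console/notebook output.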