File size: 3,090 Bytes
d0d12ff
c62cf54
0b8d742
d0d12ff
 
5bd57b6
b818b3f
4362d26
d0d12ff
4362d26
dad4689
d0d12ff
38ede89
b80761a
 
4362d26
1937eb3
 
c3144ec
 
d0d12ff
c3144ec
1937eb3
28d873f
1937eb3
 
 
28d873f
1937eb3
 
28d873f
 
 
 
 
 
 
 
 
 
 
 
d0d12ff
 
4362d26
92585dc
 
 
 
 
 
4362d26
 
 
 
c62cf54
28d873f
74d1efc
 
0faca03
 
1fc5a3a
0faca03
 
28d873f
74d1efc
3824c46
92585dc
28d873f
92585dc
 
 
28d873f
74b9951
74d1efc
 
28d873f
d0d12ff
92585dc
b818b3f
92585dc
d0d12ff
 
92585dc
c62cf54
92585dc
a8f6d87
46203c4
 
 
1b01c22
 
c62cf54
 
1b01c22
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import copy
import os
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Any

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Inference-only app: switch autograd off (grad mode is per-thread in PyTorch).
torch.set_grad_enabled(False)
model_name = "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ"
# Raises KeyError at import time when the `hf_token` env var is absent.
# NOTE(review): the token is never passed to `pipeline(...)` below —
# confirm whether it was meant to be forwarded for authenticated downloads.
token = os.environ['hf_token']
pipe = pipeline(task="text-generation", model=model_name, device="cuda")
# Keyword arguments forwarded to the `pipe(...)` call in `generate`.
generate_kwargs = dict(max_new_tokens=20)


# System instructions for the completion model. The final sentence refers to
# the near-empty user turn and placeholder assistant turn that follow this
# prompt in `start_messages`. The model is asked for ';'-separated
# completions; `generate` returns its reply text as-is without splitting it.
# This string is sent to the model verbatim (including trailing spaces).
system_prompt = '''You are given a partial input text for another AI chat interface. 
Propose auto-completion to the text. You have several roles:
- Fight under-specification.
- Complete text to save the user time.

Don't suggest anything if there are no good suggestions. 
Make sure the suggestions are valid completions of the text! Suggest only up to 5 words ahead. The scheme of your answer should be "answer1;answer2;answer3" (return between 0 to 4 answers).
Answers should be only the completions themselves. 
You will now get a blank message from the user and then after your answer, the user will give you the text to complete.

'''


# Few-shot examples illustrating the "answer1;answer2;..." reply format.
# NOTE(review): this name is not referenced anywhere in this file — it is not
# appended to `system_prompt` nor included in `start_messages`, so the model
# never sees these examples. Confirm whether that is intentional.
extra_prompt = '''
Examples: 
(1)
User: "Help me write a sentiment analysis pipeline"
Assistant: "using huggingface;using NLTK;using python"

(2)
User: "My name is"
Assistant: "" (nothing much to contribute at this point. return nothing)

(3)
User: "Help me find a present for my"
Assistant: "girlfriend;mother;father;friend"
'''


# Fixed conversation prefix prepended to every request: the system prompt,
# a near-empty user turn and a placeholder assistant reply. Its KV cache is
# precomputed once by `set_past_key_values`.
start_messages = [
    dict(role='system', content=system_prompt),
    dict(role='user', content='  '),
    dict(role='assistant', content='<Waiting for text>'),
]


# functions
@dataclass
class PastKV:
    """Mutable box holding a precomputed KV cache.

    Keeping the cache on an object lets functions swap it in place
    (``holder.past_key_values = ...``) without rebinding a module global.
    """

    # Cache object returned by the model's forward pass; None until primed.
    past_key_values: Any = None


# Module-level holder: `set_past_key_values` fills it, `generate` reads it.
past_key_values = PastKV()


def past_kv_to_device(past_kv, device):
    """Copy an iterable of ``(key, value)`` tensor pairs onto *device*.

    Every tensor is moved with ``.to(device)`` and then detached from any
    autograd graph. Returns a tuple holding one ``(key, value)`` tuple per
    pair in *past_kv*, in the original order.
    """
    moved_pairs = []
    for key_tensor, value_tensor in past_kv:
        moved_pairs.append(
            (key_tensor.to(device).detach(), value_tensor.to(device).detach())
        )
    return tuple(moved_pairs)


@spaces.GPU
def set_past_key_values():
    """Precompute the KV cache for ``start_messages`` and store it globally.

    The chat-templated ``start_messages`` prefix is tokenized, verified to be
    a token-level prefix of the same conversation continued with a user turn,
    and run through the model once. The resulting ``past_key_values`` are
    stored on the module-level ``PastKV`` holder for ``generate`` to reuse.

    Returns:
        True once the cache has been stored.

    Raises:
        RuntimeError: if the templated prefix is not a token prefix of the
            continued conversation, in which case the cache would not match
            the tokens later fed to the model.
    """
    model, tokenizer = pipe.model, pipe.tokenizer
    prefix_ids = tokenizer.apply_chat_template(start_messages, return_tensors='pt')

    # The cache is only reusable if appending a user turn leaves the
    # tokenized prefix unchanged. Check explicitly instead of with `assert`,
    # which is stripped when Python runs with -O.
    test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
    test_ids = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
    prefix_len = prefix_ids.shape[1]
    is_prefix = (
        test_ids.shape[1] >= prefix_len
        and bool((test_ids[:, :prefix_len] == prefix_ids).all())
    )
    if not is_prefix:
        raise RuntimeError(
            "Chat-templated start_messages are not a token prefix of a "
            "continued conversation; the cached KV state cannot be reused."
        )

    # Grad mode is thread-local, so do not rely on the module-level
    # `torch.set_grad_enabled(False)` having run on this thread; otherwise
    # the cached tensors could keep an autograd graph alive if any weights
    # require grad.
    with torch.no_grad():
        outputs = model(prefix_ids.to(model.device), use_cache=True)
    past_key_values.past_key_values = outputs.past_key_values
    return True
    

@spaces.GPU
def generate(text, past_key_values=None):
    """Return the model's auto-completion reply for the partial input *text*.

    ``start_messages`` is extended with *text* as a user turn and sent
    through ``pipe`` together with a copy of the precomputed prefix cache,
    which is intended to let generation skip re-encoding the shared prefix.

    Args:
        text: Partial user input to complete.
        past_key_values: ``PastKV`` holder with the prefix cache. Defaults to
            the module-level holder, so calling with only *text* (as a
            one-input Gradio interface does) also works.

    Returns:
        Content of the last message in the conversation returned by the
        pipeline (the assistant's reply).
    """
    if past_key_values is None:
        # The parameter shadows the module-level holder of the same name,
        # so fetch the global explicitly.
        past_key_values = globals()['past_key_values']

    messages = [
        *start_messages,
        {'role': 'user', 'content': text},
    ]
    # Depending on the transformers version, generation may extend a Cache
    # object in place. Hand it a private copy so the shared prefix cache
    # stays clean for later requests.
    prompt_cache = copy.deepcopy(past_key_values.past_key_values)
    response = pipe(
        messages,
        past_key_values=prompt_cache,
        **generate_kwargs,
    )[0]['generated_text']
    return response[-1]['content']

    
if __name__ == "__main__":
    # Prime the prefix cache once at startup. Grad mode is thread-local, so
    # this context only covers calls made on this thread.
    with torch.no_grad():
        set_past_key_values()
    # Log a short summary rather than dumping every cached tensor.
    print(f"Primed prompt cache: {type(past_key_values.past_key_values).__name__}")
    # The Interface has a single textbox input, so Gradio calls the handler
    # with just the text; bind the shared cache holder here.
    demo = gr.Interface(
        partial(generate, past_key_values=past_key_values),
        inputs="textbox",
        outputs="textbox",
    )
    demo.launch()