Mattral committed
Commit 60362b0 · verified · 1 Parent(s): 233f32d

Create src/model.py

Files changed (1)
  1. src/model.py +80 -0
src/model.py ADDED
@@ -0,0 +1,80 @@
+ import os
+ from typing import Iterator, List, Tuple
+
+ from text_generation import Client
+
+ model_id = "mistralai/Mistral-7B-Instruct-v0.1"
+
+ API_URL = "https://api-inference.huggingface.co/models/" + model_id
+ HF_TOKEN = os.environ.get("HF_READ_TOKEN", None)
+
+ client = Client(
+     API_URL,
+     headers={"Authorization": f"Bearer {HF_TOKEN}"},
+ )
+ EOS_STRING = "</s>"
+ EOT_STRING = "<EOT>"
+
+
+ def _get_prompt(
+     message: str, chat_history: List[Tuple[str, str]], system_prompt: str
+ ) -> str:
+     """
+     Build the instruction-formatted prompt to send to the model.
+     :param message: The message to send to the model.
+     :param chat_history: The chat history as (user_input, response) pairs.
+     :param system_prompt: The system prompt.
+     :return: The formatted prompt string.
+     """
+     texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
+     do_strip = False
+     for user_input, response in chat_history:
+         user_input = user_input.strip() if do_strip else user_input
+         do_strip = True
+         texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
+     message = message.strip() if do_strip else message
+     texts.append(f"{message} [/INST]")
+     return "".join(texts)
+
+
+ def run(
+     message: str,
+     chat_history: List[Tuple[str, str]],
+     system_prompt: str,
+     max_new_tokens: int = 2048,
+     temperature: float = 0.1,
+     top_p: float = 0.9,
+     top_k: int = 50,
+ ) -> Iterator[str]:
+     """
+     Stream a response from the model.
+     :param message: The message to send to the model.
+     :param chat_history: The chat history as (user_input, response) pairs.
+     :param system_prompt: The system prompt.
+     :param max_new_tokens: The maximum number of tokens to generate.
+     :param temperature: The sampling temperature.
+     :param top_p: The nucleus sampling (top-p) threshold.
+     :param top_k: The top-k sampling cutoff.
+     :return: An iterator that yields the cumulative generated text.
+     """
+     prompt = _get_prompt(message, chat_history, system_prompt)
+
+     generate_kwargs = dict(
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+     )
+     stream = client.generate_stream(prompt, **generate_kwargs)
+     output = ""
+     for response in stream:
+         if any(
+             [end_token in response.token.text for end_token
+              in [EOS_STRING, EOT_STRING]]
+         ):
+             return output
+         else:
+             output += response.token.text
+             yield output
+     return output
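
A minimal usage sketch for the streaming run() helper follows; it is not part of this commit. It assumes HF_READ_TOKEN is exported in the environment and that the new file is importable as src.model; the chat history, system prompt, and message below are illustrative placeholders.

# hypothetical caller, not included in the commit above
from src.model import run

chat_history = []  # (user_input, response) pairs from earlier turns
system_prompt = "You are a helpful assistant."

# run() yields the cumulative text after each streamed token, so print
# only the newly appended suffix to avoid repeating earlier output.
printed = 0
for partial in run("What is a Python context manager?", chat_history, system_prompt):
    print(partial[printed:], end="", flush=True)
    printed = len(partial)
print()

Because the generator yields the full accumulated string rather than per-token deltas, a caller that wants only the new text must track how much it has already consumed, as the printed counter does above.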