sultan-hassan commited on
Commit
15411db
·
verified ·
1 Parent(s): 46b89af

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +69 -0
utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ChatState():
2
+ """
3
+ Manages the conversation history for a turn-based chatbot
4
+ Follows the turn-based conversation guidelines for the Gemma family of models
5
+ documented at https://ai.google.dev/gemma/docs/formatting
6
+ """
7
+
8
+
9
+ __START_TURN_USER__ = "Instruction:\n"
10
+ __START_TURN_MODEL__ = "\n\nResponse:\n"
11
+ __END_TURN__ = ""#"\n"
12
+
13
+
14
+ def __init__(self, model, system=""):
15
+ """
16
+ Initializes the chat state.
17
+
18
+ Args:
19
+ model: The language model to use for generating responses.
20
+ system: (Optional) System instructions or bot description.
21
+ """
22
+ self.model = model
23
+ self.system = system
24
+ self.history = []
25
+
26
+ def add_to_history_as_user(self, message):
27
+ """
28
+ Adds a user message to the history with start/end turn markers.
29
+ """
30
+ self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)
31
+
32
+ def add_to_history_as_model(self, message):
33
+ """
34
+ Adds a model response to the history with the start turn marker.
35
+ Model will generate end turn marker.
36
+ """
37
+ self.history.append(self.__START_TURN_MODEL__ + message+ "\n")
38
+
39
+ def get_history(self):
40
+ """
41
+ Returns the entire chat history as a single string.
42
+ """
43
+ return "".join([*self.history])
44
+
45
+ def get_full_prompt(self):
46
+ """
47
+ Builds the prompt for the language model, including history and system description.
48
+ """
49
+ prompt = self.get_history() + self.__START_TURN_MODEL__
50
+ if len(self.system)>0:
51
+ prompt = self.system + "\n" + prompt
52
+ return prompt
53
+
54
+ def send_message(self, message):
55
+ """
56
+ Handles sending a user message and getting a model response.
57
+
58
+ Args:
59
+ message: The user's message.
60
+
61
+ Returns:
62
+ The model's response.
63
+ """
64
+ self.add_to_history_as_user(message)
65
+ prompt = self.get_full_prompt()
66
+ response = self.model.generate(prompt, max_length=4096)
67
+ result = response.replace(prompt, "") # Extract only the new response
68
+ self.add_to_history_as_model(result)
69
+ return result