File size: 1,014 Bytes
c9720d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from transformers import pipeline

# Hugging Face Hub checkpoint that HFAgent hands to `transformers.pipeline`.
# Alternative (larger) checkpoint, kept commented out for easy switching:
# MODEL_CKPT = "HuggingFaceH4/zephyr-7b-beta"
MODEL_CKPT = "AVeryRealHuman/DialoGPT-small-TonyStark"

class HFAgent:
    """Thin callable wrapper around a Hugging Face ``transformers`` pipeline.

    Calling the agent (or :meth:`generate`) forwards the chat history straight
    to the underlying pipeline and returns whatever the pipeline returns.
    """

    def __init__(self, model=None, task="conversational", **pipeline_kwargs):
        """Build the underlying pipeline.

        Args:
            model: Checkpoint name or local path. When ``None`` (the default),
                the module-level ``MODEL_CKPT`` is used.
            task: Task name passed to ``transformers.pipeline``.
            **pipeline_kwargs: Extra keyword arguments forwarded unchanged to
                ``transformers.pipeline`` (e.g. ``device``).
        """
        # NOTE(review): the "conversational" task was deprecated and later
        # removed in newer transformers releases -- confirm the pinned version
        # still supports it, or pass a different `task` for chat models.
        if model is None:
            # Resolved here (not as a default argument value) so the module
            # constant is looked up at construction time.
            model = MODEL_CKPT
        self.pipe = pipeline(task, model=model, **pipeline_kwargs)

    def generate(self, chat_history, **generate_kwargs):
        """Run the pipeline on *chat_history* and return its raw output.

        Any extra keyword arguments are forwarded to the pipeline call; which
        ones are accepted depends on the pipeline/task in use.
        """
        return self.pipe(chat_history, **generate_kwargs)

    def __call__(self, chat_history, **generate_kwargs):
        """Alias for :meth:`generate` so the agent can be used as a callable."""
        return self.generate(chat_history, **generate_kwargs)

    def __repr__(self):
        # Show the checkpoint identifier rather than the model object itself:
        # for PyTorch models the object's repr is typically a long multi-line
        # dump of the whole architecture. Falls back to the object when no
        # `name_or_path` is available. `str()` falls back to this method.
        model = self.pipe.model
        name = getattr(model, "name_or_path", None) or model
        return f"{type(self).__name__}(model={name})"
    

## For testing purposes
# def main():
#     agent = HFAgent()
#     messages = [
#         {
#             "role": "system",
#             "content": "You are a friendly chatbot who always responds in the style of a pirate",
#         },
#         {"role": "user", "content": "How many hotdogs can a human eat in one sitting?"},
#     ]
#     new_messages = agent(messages)
#     print(new_messages[-1])

# if __name__ == "__main__":
#     main()