|
|
|
import time |
|
import unittest |
|
from unittest.mock import patch |
|
|
|
from autogpt.chat import create_chat_message, generate_context |
|
|
|
|
|
class TestChat(unittest.TestCase): |
|
|
|
def test_happy_path_role_content(self): |
|
result = create_chat_message("system", "Hello, world!") |
|
self.assertEqual(result, {"role": "system", "content": "Hello, world!"}) |
|
|
|
|
|
def test_empty_role_content(self): |
|
result = create_chat_message("", "") |
|
self.assertEqual(result, {"role": "", "content": ""}) |
|
|
|
|
|
@patch("time.strftime") |
|
def test_generate_context_empty_inputs(self, mock_strftime): |
|
|
|
mock_strftime.return_value = "Sat Apr 15 00:00:00 2023" |
|
|
|
prompt = "" |
|
relevant_memory = "" |
|
full_message_history = [] |
|
model = "gpt-3.5-turbo-0301" |
|
|
|
|
|
result = generate_context(prompt, relevant_memory, full_message_history, model) |
|
|
|
|
|
expected_result = ( |
|
-1, |
|
47, |
|
3, |
|
[ |
|
{"role": "system", "content": ""}, |
|
{ |
|
"role": "system", |
|
"content": f"The current time and date is {time.strftime('%c')}", |
|
}, |
|
{ |
|
"role": "system", |
|
"content": f"This reminds you of these events from your past:\n\n\n", |
|
}, |
|
], |
|
) |
|
self.assertEqual(result, expected_result) |
|
|
|
|
|
def test_generate_context_valid_inputs(self): |
|
|
|
prompt = "What is your favorite color?" |
|
relevant_memory = "You once painted your room blue." |
|
full_message_history = [ |
|
create_chat_message("user", "Hi there!"), |
|
create_chat_message("assistant", "Hello! How can I assist you today?"), |
|
create_chat_message("user", "Can you tell me a joke?"), |
|
create_chat_message( |
|
"assistant", |
|
"Why did the tomato turn red? Because it saw the salad dressing!", |
|
), |
|
create_chat_message("user", "Haha, that's funny."), |
|
] |
|
model = "gpt-3.5-turbo-0301" |
|
|
|
|
|
result = generate_context(prompt, relevant_memory, full_message_history, model) |
|
|
|
|
|
self.assertIsInstance(result[0], int) |
|
self.assertIsInstance(result[1], int) |
|
self.assertIsInstance(result[2], int) |
|
self.assertIsInstance(result[3], list) |
|
self.assertGreaterEqual(result[0], 0) |
|
self.assertGreaterEqual(result[1], 0) |
|
self.assertGreaterEqual(result[2], 0) |
|
self.assertGreaterEqual( |
|
len(result[3]), 3 |
|
) |
|
self.assertLessEqual( |
|
result[1], 2048 |
|
) |
|
|