File size: 1,699 Bytes
4bdab37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import unittest
from unittest import TestCase

from chatarena.message import Message, MessagePool


class TestMessagePool(TestCase):
    def test_message_fully_observable(self):
        message_pool = MessagePool()
        p1_message = Message("player1", "I'm player 1", 1)
        p2_message = Message("player2", "I'm player 2", 1)

        message_pool.append_message(p1_message)
        message_pool.append_message(p2_message)
        p1_observation = message_pool.get_visible_messages("player1", 2)
        assert p1_observation[0].msg_hash == p1_message.msg_hash
        assert p1_observation[1].msg_hash == p2_message.msg_hash

    def test_message_by_turn(self):
        message_pool = MessagePool()
        p1_message = Message("player1", "I'm player 1", 1)
        p2_message = Message("player2", "I'm player 2", 2)
        message_pool.append_message(p1_message)
        message_pool.append_message(p2_message)
        p1_observation = message_pool.get_visible_messages("player1", 2)
        assert p1_observation[0].msg_hash == p1_message.msg_hash
        assert len(p1_observation) == 1

    def test_message_partial_observation(self):
        message_pool = MessagePool()
        p1_message = Message("player1", "I'm player 1", 1)
        p2_message = Message("player2", "I'm player 2", 1, visible_to=["player2"])

        message_pool.append_message(p1_message)
        message_pool.append_message(p2_message)
        p1_observation = message_pool.get_visible_messages("player1", 2)
        p2_observation = message_pool.get_visible_messages("player2", 2)
        assert len(p1_observation) == 1
        assert len(p2_observation) == 2


# Run this module's tests through the unittest command-line runner when the
# file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()