diff options
Diffstat (limited to 'tests/replication')
-rw-r--r-- | tests/replication/slave/storage/_base.py | 6 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_account_data.py | 14 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_events.py | 108 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_receipts.py | 6 |
4 files changed, 66 insertions, 68 deletions
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index a103e7be80..65df116efc 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -31,6 +31,7 @@ from tests.utils import setup_test_homeserver class TestReplicationClientHandler(ReplicationClientHandler): """Overrides on_rdata so that we can wait for it to happen""" + def __init__(self, store): super(TestReplicationClientHandler, self).__init__(store) self._rdata_awaiters = [] @@ -53,12 +54,11 @@ class BaseSlavedStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver( + self.addCleanup, "blue", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), + ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.hs.get_ratelimiter().send_message.return_value = (True, 0) diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py index adf226404e..87cc2b2fba 100644 --- a/tests/replication/slave/storage/test_account_data.py +++ b/tests/replication/slave/storage/test_account_data.py @@ -29,20 +29,14 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): @defer.inlineCallbacks def test_user_account_data(self): - yield self.master_store.add_account_data_for_user( - USER_ID, TYPE, {"a": 1} - ) + yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) yield self.replicate() yield self.check( - "get_global_account_data_by_type_for_user", - [TYPE, USER_ID], {"a": 1} + "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} ) - yield self.master_store.add_account_data_for_user( - USER_ID, TYPE, {"a": 2} - ) + yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) yield self.replicate() yield self.check( - "get_global_account_data_by_type_for_user", - [TYPE, USER_ID], {"a": 2} + "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index f5b47f5ec0..2ba80ccdcf 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -38,6 +38,7 @@ def patch__eq__(cls): def unpatch(): if eq is not None: cls.__eq__ = eq + return unpatch @@ -48,10 +49,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def setUp(self): # Patch up the equality operator for events so that we can check # whether lists of events match using assertEquals - self.unpatches = [ - patch__eq__(_EventInternalMetadata), - patch__eq__(FrozenEvent), - ] + self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super(SlavedEventStoreTestCase, self).setUp() def tearDown(self): @@ -61,33 +59,27 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def test_get_latest_event_ids_in_room(self): create = yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.replicate() - yield self.check( - "get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id] - ) + yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) join = yield self.persist( - type="m.room.member", key=USER_ID, membership="join", + type="m.room.member", + key=USER_ID, + membership="join", prev_events=[(create.event_id, {})], ) yield self.replicate() - yield self.check( - "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] - ) + yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) @defer.inlineCallbacks def test_redactions(self): yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist( - type="m.room.message", msgtype="m.text", body="Hello" - ) + msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") yield self.replicate() yield self.check("get_event", [msg.event_id], msg) - redaction = yield self.persist( - type="m.room.redaction", redacts=msg.event_id - ) + redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id) yield self.replicate() msg_dict = msg.get_dict() @@ -102,9 +94,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist( - type="m.room.message", msgtype="m.text", body="Hello" - ) + msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") yield self.replicate() yield self.check("get_event", [msg.event_id], msg) @@ -122,19 +112,29 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): @defer.inlineCallbacks def test_invites(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.check("get_invited_rooms_for_user", [USER_ID_2], []) event = yield self.persist( type="m.room.member", key=USER_ID_2, membership="invite" ) yield self.replicate() - yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser( - ROOM_ID, USER_ID, "invite", event.event_id, - event.internal_metadata.stream_ordering - )]) + yield self.check( + "get_invited_rooms_for_user", + [USER_ID_2], + [ + RoomsForUser( + ROOM_ID, + USER_ID, + "invite", + event.event_id, + event.internal_metadata.stream_ordering, + ) + ], + ) @defer.inlineCallbacks def test_push_actions_for_user(self): - yield self.persist(type="m.room.create", creator=USER_ID) + yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.join", key=USER_ID, membership="join") yield self.persist( type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" @@ -146,40 +146,55 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): yield self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 0} + {"highlight_count": 0, "notify_count": 0}, ) yield self.persist( - type="m.room.message", msgtype="m.text", body="world", + type="m.room.message", + msgtype="m.text", + body="world", push_actions=[(USER_ID_2, ["notify"])], ) yield self.replicate() yield self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 1} + {"highlight_count": 0, "notify_count": 1}, ) yield self.persist( - type="m.room.message", msgtype="m.text", body="world", - push_actions=[(USER_ID_2, [ - "notify", {"set_tweak": "highlight", "value": True} - ])], + type="m.room.message", + msgtype="m.text", + body="world", + push_actions=[ + (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}]) + ], ) yield self.replicate() yield self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2} + {"highlight_count": 1, "notify_count": 2}, ) event_id = 0 @defer.inlineCallbacks def persist( - self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, - state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None, + self, + sender=USER_ID, + room_id=ROOM_ID, + type={}, + key=None, + internal={}, + state=None, + reset_state=False, + backfill=False, + depth=None, + prev_events=[], + auth_events=[], + prev_state=[], + redacts=None, push_actions=[], **content ): @@ -219,34 +234,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.event_id += 1 if state is not None: - state_ids = { - key: e.event_id for key, e in state.items() - } + state_ids = {key: e.event_id for key, e in state.items()} context = EventContext.with_state( - state_group=None, - current_state_ids=state_ids, - prev_state_ids=state_ids + state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids ) else: state_handler = self.hs.get_state_handler() context = yield state_handler.compute_event_context(event) yield self.master_store.add_push_actions_to_staging( - event.event_id, { - user_id: actions - for user_id, actions in push_actions - }, + event.event_id, {user_id: actions for user_id, actions in push_actions} ) ordering = None if backfill: - yield self.master_store.persist_events( - [(event, context)], backfilled=True - ) + yield self.master_store.persist_events([(event, context)], backfilled=True) else: - ordering, _ = yield self.master_store.persist_event( - event, context, - ) + ordering, _ = yield self.master_store.persist_event(event, context) if ordering: event.internal_metadata.stream_ordering = ordering diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index e6d670cc1f..ae1adeded1 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -34,6 +34,6 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} ) yield self.replicate() - yield self.check("get_receipts_for_user", [USER_ID, "m.read"], { - ROOM_ID: EVENT_ID - }) + yield self.check( + "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID} + ) |