summary refs log tree commit diff
path: root/tests/replication
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication')
-rw-r--r--tests/replication/slave/storage/_base.py6
-rw-r--r--tests/replication/slave/storage/test_account_data.py14
-rw-r--r--tests/replication/slave/storage/test_events.py108
-rw-r--r--tests/replication/slave/storage/test_receipts.py6
4 files changed, 66 insertions, 68 deletions
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index a103e7be80..65df116efc 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -31,6 +31,7 @@ from tests.utils import setup_test_homeserver
 
 class TestReplicationClientHandler(ReplicationClientHandler):
     """Overrides on_rdata so that we can wait for it to happen"""
+
     def __init__(self, store):
         super(TestReplicationClientHandler, self).__init__(store)
         self._rdata_awaiters = []
@@ -53,12 +54,11 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield setup_test_homeserver(
+            self.addCleanup,
             "blue",
             http_client=None,
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=[
-                "send_message",
-            ]),
+            ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
         self.hs.get_ratelimiter().send_message.return_value = (True, 0)
 
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
index adf226404e..87cc2b2fba 100644
--- a/tests/replication/slave/storage/test_account_data.py
+++ b/tests/replication/slave/storage/test_account_data.py
@@ -29,20 +29,14 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
 
     @defer.inlineCallbacks
     def test_user_account_data(self):
-        yield self.master_store.add_account_data_for_user(
-            USER_ID, TYPE, {"a": 1}
-        )
+        yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
         yield self.replicate()
         yield self.check(
-            "get_global_account_data_by_type_for_user",
-            [TYPE, USER_ID], {"a": 1}
+            "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
         )
 
-        yield self.master_store.add_account_data_for_user(
-            USER_ID, TYPE, {"a": 2}
-        )
+        yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
         yield self.replicate()
         yield self.check(
-            "get_global_account_data_by_type_for_user",
-            [TYPE, USER_ID], {"a": 2}
+            "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
         )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index f5b47f5ec0..2ba80ccdcf 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -38,6 +38,7 @@ def patch__eq__(cls):
     def unpatch():
         if eq is not None:
             cls.__eq__ = eq
+
     return unpatch
 
 
@@ -48,10 +49,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
     def setUp(self):
         # Patch up the equality operator for events so that we can check
         # whether lists of events match using assertEquals
-        self.unpatches = [
-            patch__eq__(_EventInternalMetadata),
-            patch__eq__(FrozenEvent),
-        ]
+        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
         return super(SlavedEventStoreTestCase, self).setUp()
 
     def tearDown(self):
@@ -61,33 +59,27 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
     def test_get_latest_event_ids_in_room(self):
         create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.replicate()
-        yield self.check(
-            "get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]
-        )
+        yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
 
         join = yield self.persist(
-            type="m.room.member", key=USER_ID, membership="join",
+            type="m.room.member",
+            key=USER_ID,
+            membership="join",
             prev_events=[(create.event_id, {})],
         )
         yield self.replicate()
-        yield self.check(
-            "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
-        )
+        yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
 
     @defer.inlineCallbacks
     def test_redactions(self):
         yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.persist(type="m.room.member", key=USER_ID, membership="join")
 
-        msg = yield self.persist(
-            type="m.room.message", msgtype="m.text", body="Hello"
-        )
+        msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
         yield self.replicate()
         yield self.check("get_event", [msg.event_id], msg)
 
-        redaction = yield self.persist(
-            type="m.room.redaction", redacts=msg.event_id
-        )
+        redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
         yield self.replicate()
 
         msg_dict = msg.get_dict()
@@ -102,9 +94,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.persist(type="m.room.member", key=USER_ID, membership="join")
 
-        msg = yield self.persist(
-            type="m.room.message", msgtype="m.text", body="Hello"
-        )
+        msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
         yield self.replicate()
         yield self.check("get_event", [msg.event_id], msg)
 
@@ -122,19 +112,29 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
     @defer.inlineCallbacks
     def test_invites(self):
+        yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
         event = yield self.persist(
             type="m.room.member", key=USER_ID_2, membership="invite"
         )
         yield self.replicate()
-        yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser(
-            ROOM_ID, USER_ID, "invite", event.event_id,
-            event.internal_metadata.stream_ordering
-        )])
+        yield self.check(
+            "get_invited_rooms_for_user",
+            [USER_ID_2],
+            [
+                RoomsForUser(
+                    ROOM_ID,
+                    USER_ID,
+                    "invite",
+                    event.event_id,
+                    event.internal_metadata.stream_ordering,
+                )
+            ],
+        )
 
     @defer.inlineCallbacks
     def test_push_actions_for_user(self):
-        yield self.persist(type="m.room.create", creator=USER_ID)
+        yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.persist(type="m.room.join", key=USER_ID, membership="join")
         yield self.persist(
             type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
@@ -146,40 +146,55 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         yield self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "notify_count": 0}
+            {"highlight_count": 0, "notify_count": 0},
         )
 
         yield self.persist(
-            type="m.room.message", msgtype="m.text", body="world",
+            type="m.room.message",
+            msgtype="m.text",
+            body="world",
             push_actions=[(USER_ID_2, ["notify"])],
         )
         yield self.replicate()
         yield self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "notify_count": 1}
+            {"highlight_count": 0, "notify_count": 1},
         )
 
         yield self.persist(
-            type="m.room.message", msgtype="m.text", body="world",
-            push_actions=[(USER_ID_2, [
-                "notify", {"set_tweak": "highlight", "value": True}
-            ])],
+            type="m.room.message",
+            msgtype="m.text",
+            body="world",
+            push_actions=[
+                (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
+            ],
         )
         yield self.replicate()
         yield self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 1, "notify_count": 2}
+            {"highlight_count": 1, "notify_count": 2},
         )
 
     event_id = 0
 
     @defer.inlineCallbacks
     def persist(
-        self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
-        state=None, reset_state=False, backfill=False,
-        depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
+        self,
+        sender=USER_ID,
+        room_id=ROOM_ID,
+        type={},
+        key=None,
+        internal={},
+        state=None,
+        reset_state=False,
+        backfill=False,
+        depth=None,
+        prev_events=[],
+        auth_events=[],
+        prev_state=[],
+        redacts=None,
         push_actions=[],
         **content
     ):
@@ -219,34 +234,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.event_id += 1
 
         if state is not None:
-            state_ids = {
-                key: e.event_id for key, e in state.items()
-            }
+            state_ids = {key: e.event_id for key, e in state.items()}
             context = EventContext.with_state(
-                state_group=None,
-                current_state_ids=state_ids,
-                prev_state_ids=state_ids
+                state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
             )
         else:
             state_handler = self.hs.get_state_handler()
             context = yield state_handler.compute_event_context(event)
 
         yield self.master_store.add_push_actions_to_staging(
-            event.event_id, {
-                user_id: actions
-                for user_id, actions in push_actions
-            },
+            event.event_id, {user_id: actions for user_id, actions in push_actions}
         )
 
         ordering = None
         if backfill:
-            yield self.master_store.persist_events(
-                [(event, context)], backfilled=True
-            )
+            yield self.master_store.persist_events([(event, context)], backfilled=True)
         else:
-            ordering, _ = yield self.master_store.persist_event(
-                event, context,
-            )
+            ordering, _ = yield self.master_store.persist_event(event, context)
 
         if ordering:
             event.internal_metadata.stream_ordering = ordering
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index e6d670cc1f..ae1adeded1 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -34,6 +34,6 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
             ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
         )
         yield self.replicate()
-        yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {
-            ROOM_ID: EVENT_ID
-        })
+        yield self.check(
+            "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
+        )