summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/storage/test_account_data.py4
-rw-r--r--tests/storage/test_background_update.py4
-rw-r--r--tests/storage/test_cleanup_extrems.py3
-rw-r--r--tests/storage/test_client_ips.py24
-rw-r--r--tests/storage/test_event_chain.py23
-rw-r--r--tests/storage/test_event_federation.py17
-rw-r--r--tests/storage/test_event_push_actions.py4
-rw-r--r--tests/storage/test_events.py6
-rw-r--r--tests/storage/test_id_generators.py30
-rw-r--r--tests/storage/test_monthly_active_users.py2
-rw-r--r--tests/storage/test_redaction.py9
-rw-r--r--tests/storage/test_registration.py11
12 files changed, 90 insertions, 47 deletions
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 673e1fe3e3..38444e48e2 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -96,7 +96,9 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         # No ignored_users key.
         self.get_success(
             self.store.add_account_data_for_user(
-                self.user, AccountDataTypes.IGNORED_USER_LIST, {},
+                self.user,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {},
             )
         )
 
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 02aae1c13d..1b4fae0bb5 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -67,7 +67,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         async def update(progress, count):
             self.assertEqual(progress, {"my_key": 2})
             self.assertAlmostEqual(
-                count, target_background_update_duration_ms / duration_ms, places=0,
+                count,
+                target_background_update_duration_ms / duration_ms,
+                places=0,
             )
             await self.updates._end_background_update("test_update")
             return count
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index c13a57dad1..7791138688 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -43,8 +43,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         self.room_id = info["room_id"]
 
     def run_background_update(self):
-        """Re run the background update to clean up the extremities.
-        """
+        """Re run the background update to clean up the extremities."""
         # Make sure we don't clash with in progress updates.
         self.assertTrue(
             self.store.db_pool.updates._all_done, "Background updates are still ongoing"
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index a69117c5a9..34e6526097 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -41,7 +41,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         device_id = "MY_DEVICE"
 
         # Insert a user IP
-        self.get_success(self.store.store_device(user_id, device_id, "display name",))
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id,
+                "display name",
+            )
+        )
         self.get_success(
             self.store.insert_client_ip(
                 user_id, "access_token", "ip", "user_agent", device_id
@@ -214,7 +220,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         device_id = "MY_DEVICE"
 
         # Insert a user IP
-        self.get_success(self.store.store_device(user_id, device_id, "display name",))
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id,
+                "display name",
+            )
+        )
         self.get_success(
             self.store.insert_client_ip(
                 user_id, "access_token", "ip", "user_agent", device_id
@@ -303,7 +315,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         device_id = "MY_DEVICE"
 
         # Insert a user IP
-        self.get_success(self.store.store_device(user_id, device_id, "display name",))
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id,
+                "display name",
+            )
+        )
         self.get_success(
             self.store.insert_client_ip(
                 user_id, "access_token", "ip", "user_agent", device_id
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 0c46ad595b..16daa66cc9 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -90,7 +90,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
                     "content": {"tag": "power"},
                 },
             ).build(
-                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id],
             )
         )
 
@@ -226,7 +227,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
 
             self.assertFalse(
                 link_map.exists_path_from(
-                    chain_map[create.event_id], chain_map[event.event_id],
+                    chain_map[create.event_id],
+                    chain_map[event.event_id],
                 ),
             )
 
@@ -287,7 +289,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
                     "content": {"tag": "power"},
                 },
             ).build(
-                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id],
             )
         )
 
@@ -373,7 +376,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
             )
 
     def persist(
-        self, events: List[EventBase],
+        self,
+        events: List[EventBase],
     ):
         """Persist the given events and check that the links generated match
         those given.
@@ -394,7 +398,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
             persist_events_store._persist_event_auth_chain_txn(txn, events)
 
         self.get_success(
-            persist_events_store.db_pool.runInteraction("_persist", _persist,)
+            persist_events_store.db_pool.runInteraction(
+                "_persist",
+                _persist,
+            )
         )
 
     def fetch_chains(
@@ -447,8 +454,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
 
 class LinkMapTestCase(unittest.TestCase):
     def test_simple(self):
-        """Basic tests for the LinkMap.
-        """
+        """Basic tests for the LinkMap."""
         link_map = _LinkMap()
 
         link_map.add_link((1, 1), (2, 1), new=False)
@@ -490,8 +496,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         self.requester = create_requester(self.user_id)
 
     def _generate_room(self) -> Tuple[str, List[Set[str]]]:
-        """Insert a room without a chain cover index.
-        """
+        """Insert a room without a chain cover index."""
         room_id = self.helper.create_room_as(self.user_id, tok=self.token)
 
         # Mark the room as not having a chain cover index
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 9d04a066d8..06000f81a6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -215,7 +215,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 ],
             )
 
-        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "insert",
+                insert_event,
+            )
+        )
 
         # Now actually test that various combinations give the right result:
 
@@ -370,7 +375,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             )
 
             self.hs.datastores.persist_events._persist_event_auth_chain_txn(
-                txn, [FakeEvent("b", room_id, auth_graph["b"])],
+                txn,
+                [FakeEvent("b", room_id, auth_graph["b"])],
             )
 
             self.store.db_pool.simple_update_txn(
@@ -380,7 +386,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 updatevalues={"has_auth_chain_index": True},
             )
 
-        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "insert",
+                insert_event,
+            )
+        )
 
         # Now actually test that various combinations give the right result:
 
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index c0595963dd..485f1ee033 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -84,7 +84,9 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
             yield defer.ensureDeferred(
                 self.store.add_push_actions_to_staging(
-                    event.event_id, {user_id: action}, False,
+                    event.event_id,
+                    {user_id: action},
+                    False,
                 )
             )
             yield defer.ensureDeferred(
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 71210ce606..ed898b8dbb 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -68,16 +68,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
         self.assert_extremities([self.remote_event_1.event_id])
 
     def persist_event(self, event, state=None):
-        """Persist the event, with optional state
-        """
+        """Persist the event, with optional state"""
         context = self.get_success(
             self.state.compute_event_context(event, old_state=state)
         )
         self.get_success(self.persistence.persist_event(event, context))
 
     def assert_extremities(self, expected_extremities):
-        """Assert the current extremities for the room
-        """
+        """Assert the current extremities for the room"""
         extremities = self.get_success(
             self.store.get_prev_events_for_room(self.room_id)
         )
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 3e2fd4da01..aad6bc907e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -86,7 +86,11 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         def _insert(txn):
             txn.execute(
-                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+                "INSERT INTO foobar VALUES (?, ?)",
+                (
+                    stream_id,
+                    instance_name,
+                ),
             )
             txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
             txn.execute(
@@ -138,8 +142,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
     def test_out_of_order_finish(self):
-        """Test that IDs persisted out of order are correctly handled
-        """
+        """Test that IDs persisted out of order are correctly handled"""
 
         # Prefill table with 7 rows written by 'master'
         self._insert_rows("master", 7)
@@ -246,8 +249,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
 
     def test_get_next_txn(self):
-        """Test that the `get_next_txn` function works correctly.
-        """
+        """Test that the `get_next_txn` function works correctly."""
 
         # Prefill table with 7 rows written by 'master'
         self._insert_rows("master", 7)
@@ -386,8 +388,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
 
     def test_writer_config_change(self):
-        """Test that changing the writer config correctly works.
-        """
+        """Test that changing the writer config correctly works."""
 
         self._insert_row_with_id("first", 3)
         self._insert_row_with_id("second", 5)
@@ -434,8 +435,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
 
     def test_sequence_consistency(self):
-        """Test that we error out if the table and sequence diverges.
-        """
+        """Test that we error out if the table and sequence diverges."""
 
         # Prefill with some rows
         self._insert_row_with_id("master", 3)
@@ -452,8 +452,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
 
 class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
-    """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
-    """
+    """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
 
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
@@ -494,12 +493,15 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         return self.get_success(self.db_pool.runWithConnection(_create))
 
     def _insert_row(self, instance_name: str, stream_id: int):
-        """Insert one row as the given instance with given stream_id.
-        """
+        """Insert one row as the given instance with given stream_id."""
 
         def _insert(txn):
             txn.execute(
-                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+                "INSERT INTO foobar VALUES (?, ?)",
+                (
+                    stream_id,
+                    instance_name,
+                ),
             )
             txn.execute(
                 """
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 8d97b6d4cd..5858c7fcc4 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -198,7 +198,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
     # value, although it gets stored on the config object as mau_limits.
     @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
     def test_reap_monthly_active_users_reserved_users(self):
-        """ Tests that reaping correctly handles reaping where reserved users are
+        """Tests that reaping correctly handles reaping where reserved users are
         present"""
         threepids = self.hs.config.mau_limits_reserved_threepids
         initial_users = len(threepids)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index a6303bf0ee..b2a0e60856 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -299,8 +299,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
     def test_redact_censor(self):
-        """Test that a redacted event gets censored in the DB after a month
-        """
+        """Test that a redacted event gets censored in the DB after a month"""
 
         self.get_success(
             self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -370,8 +369,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         self.assert_dict({"content": {}}, json.loads(event_json))
 
     def test_redact_redaction(self):
-        """Tests that we can redact a redaction and can fetch it again.
-        """
+        """Tests that we can redact a redaction and can fetch it again."""
 
         self.get_success(
             self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -404,8 +402,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
     def test_store_redacted_redaction(self):
-        """Tests that we can store a redacted redaction.
-        """
+        """Tests that we can store a redacted redaction."""
 
         self.get_success(
             self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index c8c7a90e5d..4eb41c46e8 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -52,6 +52,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 "creation_ts": 1000,
                 "user_type": None,
                 "deactivated": 0,
+                "shadow_banned": 0,
             },
             (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
         )
@@ -145,7 +146,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
         try:
             yield defer.ensureDeferred(
                 self.store.validate_threepid_session(
-                    "fake_sid", "fake_client_secret", "fake_token", 0,
+                    "fake_sid",
+                    "fake_client_secret",
+                    "fake_token",
+                    0,
                 )
             )
         except ThreepidValidationError as e:
@@ -158,7 +162,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
         try:
             yield defer.ensureDeferred(
                 self.store.validate_threepid_session(
-                    "fake_sid", "fake_client_secret", "fake_token", 0,
+                    "fake_sid",
+                    "fake_client_secret",
+                    "fake_token",
+                    0,
                 )
             )
         except ThreepidValidationError as e: