summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/databases/main/test_events_worker.py1
-rw-r--r--tests/storage/test_event_chain.py10
-rw-r--r--tests/storage/test_event_federation.py9
-rw-r--r--tests/storage/test_events.py8
-rw-r--r--tests/storage/test_keys.py61
-rw-r--r--tests/storage/test_purge.py2
-rw-r--r--tests/storage/test_receipts.py6
-rw-r--r--tests/storage/test_room_search.py3
-rw-r--r--tests/storage/test_stream.py4
-rw-r--r--tests/storage/test_unsafe_locale.py2
10 files changed, 63 insertions, 43 deletions
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 9f33afcca0..9606ecc43b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
         # Persist the event which should invalidate or prefill the
         # `have_seen_event` cache so we don't return stale values.
         persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
         self.get_success(
             persistence.persist_event(
                 event,
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c070278db8..a10e5fa8b1 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
         """
 
         persist_events_store = self.hs.get_datastores().persist_events
+        assert persist_events_store is not None
 
         for e in events:
             e.internal_metadata.stream_ordering = self._next_stream_ordering
@@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
         def _persist(txn: LoggingTransaction) -> None:
             # We need to persist the events to the events and state_events
             # tables.
+            assert persist_events_store is not None
             persist_events_store._store_event_txn(
                 txn,
                 [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
@@ -540,7 +542,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 self.requester, events_and_context=[(event, context)]
             )
         )
-        state1 = set(self.get_success(context.get_current_state_ids()).values())
+        state_ids1 = self.get_success(context.get_current_state_ids())
+        assert state_ids1 is not None
+        state1 = set(state_ids1.values())
 
         event, context = self.get_success(
             event_handler.create_event(
@@ -560,7 +564,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 self.requester, events_and_context=[(event, context)]
             )
         )
-        state2 = set(self.get_success(context.get_current_state_ids()).values())
+        state_ids2 = self.get_success(context.get_current_state_ids())
+        assert state_ids2 is not None
+        state2 = set(state_ids2.values())
 
         # Delete the chain cover info.
 
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 7fd3e01364..8fc7936ab0 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
+        persist_events = hs.get_datastores().persist_events
+        assert persist_events is not None
+        self.persist_events = persist_events
 
     def test_get_prev_events_for_room(self) -> None:
         room_id = "@ROOM:local"
@@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                     },
                 )
 
-            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+            self.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
                     cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 )
 
             # Insert all events apart from 'B'
-            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+            self.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
                     cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 updatevalues={"has_auth_chain_index": False},
             )
 
-            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+            self.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
             )
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 05661a537d..e67dd0589d 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
         self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
     ) -> None:
         self.state = self.hs.get_state_handler()
-        self._persistence = self.hs.get_storage_controllers().persistence
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self._persistence = persistence
         self._state_storage_controller = self.hs.get_storage_controllers().state
         self.store = self.hs.get_datastores().main
 
@@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
         self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
     ) -> None:
         self.state = self.hs.get_state_handler()
-        self._persistence = self.hs.get_storage_controllers().persistence
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self._persistence = persistence
         self.store = self.hs.get_datastores().main
 
     def test_remote_user_rooms_cache_invalidated(self) -> None:
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index aa4b5bd3b1..ba68171ad7 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -16,8 +16,6 @@ import signedjson.key
 import signedjson.types
 import unpaddedbase64
 
-from twisted.internet.defer import Deferred
-
 from synapse.storage.keys import FetchKeyResult
 
 import tests.unittest
@@ -44,20 +42,26 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
 
         key_id_1 = "ed25519:key1"
         key_id_2 = "ed25519:KEY_ID_2"
-        d = store.store_server_verify_keys(
-            "from_server",
-            10,
-            [
-                ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
-                ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
-            ],
+        self.get_success(
+            store.store_server_verify_keys(
+                "from_server",
+                10,
+                [
+                    ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                    ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+                ],
+            )
         )
-        self.get_success(d)
 
-        d = store.get_server_verify_keys(
-            [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
+        res = self.get_success(
+            store.get_server_verify_keys(
+                [
+                    ("server1", key_id_1),
+                    ("server1", key_id_2),
+                    ("server1", "ed25519:key3"),
+                ]
+            )
         )
-        res = self.get_success(d)
 
         self.assertEqual(len(res.keys()), 3)
         res1 = res[("server1", key_id_1)]
@@ -82,18 +86,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         key_id_1 = "ed25519:key1"
         key_id_2 = "ed25519:key2"
 
-        d = store.store_server_verify_keys(
-            "from_server",
-            0,
-            [
-                ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
-                ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
-            ],
+        self.get_success(
+            store.store_server_verify_keys(
+                "from_server",
+                0,
+                [
+                    ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                    ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+                ],
+            )
         )
-        self.get_success(d)
 
-        d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
-        res = self.get_success(d)
+        res = self.get_success(
+            store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+        )
         self.assertEqual(len(res.keys()), 2)
 
         res1 = res[("srv1", key_id_1)]
@@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         self.assertEqual(res2.valid_until_ts, 200)
 
         # we should be able to look up the same thing again without a db hit
-        res = store.get_server_verify_keys([("srv1", key_id_1)])
-        if isinstance(res, Deferred):
-            res = self.successResultOf(res)
+        res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
         self.assertEqual(len(res.keys()), 1)
         self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
 
@@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
-        res = self.get_success(d)
+        res = self.get_success(
+            store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+        )
         self.assertEqual(len(res.keys()), 2)
 
         res1 = res[("srv1", key_id_1)]
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 010cc74c31..d8f42c5d05 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -112,7 +112,7 @@ class PurgeTests(HomeserverTestCase):
                 self.room_id, "m.room.create", ""
             )
         )
-        self.assertIsNotNone(create_event)
+        assert create_event is not None
 
         # Purge everything before this topological token
         self.get_success(
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index d8d84152dc..12c17f1073 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase):
         self.store = homeserver.get_datastores().main
 
         self.room_creator = homeserver.get_room_creation_handler()
-        self.persist_event_storage_controller = (
-            self.hs.get_storage_controllers().persistence
-        )
+        persist_event_storage_controller = self.hs.get_storage_controllers().persistence
+        assert persist_event_storage_controller is not None
+        self.persist_event_storage_controller = persist_event_storage_controller
 
         # Create a test user
         self.ourUser = UserID.from_string(OUR_USER_ID)
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 14d872514d..f183c38477 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase):
             "content": {"msgtype": "m.text", "body": 2},
             "room_id": room_id,
             "sender": user_id,
-            "depth": prev_event.depth + 1,
             "prev_events": prev_event_ids,
             "origin_server_ts": self.clock.time_msec(),
         }
@@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
                     prev_state_map,
                     for_verification=False,
                 ),
-                depth=event_dict["depth"],
+                depth=prev_event.depth + 1,
             )
         )
 
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index bc090ebce0..05dc4f64b8 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -16,7 +16,7 @@ from typing import List
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
 from synapse.api.filtering import Filter
 from synapse.rest import admin
 from synapse.rest.client import login, room
@@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase):
                 room_id=self.room_id,
                 from_key=self.from_token.room_key,
                 to_key=None,
-                direction="f",
+                direction=Direction.FORWARDS,
                 limit=10,
                 event_filter=Filter(self.hs, filter),
             )
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
index ba53c22818..19da8a9b09 100644
--- a/tests/storage/test_unsafe_locale.py
+++ b/tests/storage/test_unsafe_locale.py
@@ -14,6 +14,7 @@
 from unittest.mock import MagicMock, patch
 
 from synapse.storage.database import make_conn
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.engines._base import IncorrectDatabaseSetup
 
 from tests.unittest import HomeserverTestCase
@@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase):
 
     def test_safe_locale(self) -> None:
         database = self.hs.get_datastores().databases[0]
+        assert isinstance(database.engine, PostgresEngine)
 
         db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
         with db_conn.cursor() as txn: