diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 9f33afcca0..9606ecc43b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# Persist the event which should invalidate or prefill the
# `have_seen_event` cache so we don't return stale values.
persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
self.get_success(
persistence.persist_event(
event,
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c070278db8..a10e5fa8b1 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
"""
persist_events_store = self.hs.get_datastores().persist_events
+ assert persist_events_store is not None
for e in events:
e.internal_metadata.stream_ordering = self._next_stream_ordering
@@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events
# tables.
+ assert persist_events_store is not None
persist_events_store._store_event_txn(
txn,
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
@@ -540,7 +542,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
self.requester, events_and_context=[(event, context)]
)
)
- state1 = set(self.get_success(context.get_current_state_ids()).values())
+ state_ids1 = self.get_success(context.get_current_state_ids())
+ assert state_ids1 is not None
+ state1 = set(state_ids1.values())
event, context = self.get_success(
event_handler.create_event(
@@ -560,7 +564,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
self.requester, events_and_context=[(event, context)]
)
)
- state2 = set(self.get_success(context.get_current_state_ids()).values())
+ state_ids2 = self.get_success(context.get_current_state_ids())
+ assert state_ids2 is not None
+ state2 = set(state_ids2.values())
# Delete the chain cover info.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 7fd3e01364..8fc7936ab0 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ persist_events = hs.get_datastores().persist_events
+ assert persist_events is not None
+ self.persist_events = persist_events
def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local"
@@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 05661a537d..e67dd0589d 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
- self._persistence = self.hs.get_storage_controllers().persistence
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persistence = persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main
@@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
- self._persistence = self.hs.get_storage_controllers().persistence
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persistence = persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self) -> None:
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index aa4b5bd3b1..ba68171ad7 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -16,8 +16,6 @@ import signedjson.key
import signedjson.types
import unpaddedbase64
-from twisted.internet.defer import Deferred
-
from synapse.storage.keys import FetchKeyResult
import tests.unittest
@@ -44,20 +42,26 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2"
- d = store.store_server_verify_keys(
- "from_server",
- 10,
- [
- ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
- ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
- ],
+ self.get_success(
+ store.store_server_verify_keys(
+ "from_server",
+ 10,
+ [
+ ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
)
- self.get_success(d)
- d = store.get_server_verify_keys(
- [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
+ res = self.get_success(
+ store.get_server_verify_keys(
+ [
+ ("server1", key_id_1),
+ ("server1", key_id_2),
+ ("server1", "ed25519:key3"),
+ ]
+ )
)
- res = self.get_success(d)
self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)]
@@ -82,18 +86,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
- d = store.store_server_verify_keys(
- "from_server",
- 0,
- [
- ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
- ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
- ],
+ self.get_success(
+ store.store_server_verify_keys(
+ "from_server",
+ 0,
+ [
+ ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
)
- self.get_success(d)
- d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- res = self.get_success(d)
+ res = self.get_success(
+ store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ )
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
@@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
- res = store.get_server_verify_keys([("srv1", key_id_1)])
- if isinstance(res, Deferred):
- res = self.successResultOf(res)
+ res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
@@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.get_success(d)
- d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- res = self.get_success(d)
+ res = self.get_success(
+ store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ )
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 010cc74c31..d8f42c5d05 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -112,7 +112,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id, "m.room.create", ""
)
)
- self.assertIsNotNone(create_event)
+ assert create_event is not None
# Purge everything before this topological token
self.get_success(
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index d8d84152dc..12c17f1073 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase):
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
- self.persist_event_storage_controller = (
- self.hs.get_storage_controllers().persistence
- )
+ persist_event_storage_controller = self.hs.get_storage_controllers().persistence
+ assert persist_event_storage_controller is not None
+ self.persist_event_storage_controller = persist_event_storage_controller
# Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID)
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 14d872514d..f183c38477 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase):
"content": {"msgtype": "m.text", "body": 2},
"room_id": room_id,
"sender": user_id,
- "depth": prev_event.depth + 1,
"prev_events": prev_event_ids,
"origin_server_ts": self.clock.time_msec(),
}
@@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_state_map,
for_verification=False,
),
- depth=event_dict["depth"],
+ depth=prev_event.depth + 1,
)
)
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index bc090ebce0..05dc4f64b8 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -16,7 +16,7 @@ from typing import List
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.filtering import Filter
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase):
room_id=self.room_id,
from_key=self.from_token.room_key,
to_key=None,
- direction="f",
+ direction=Direction.FORWARDS,
limit=10,
event_filter=Filter(self.hs, filter),
)
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
index ba53c22818..19da8a9b09 100644
--- a/tests/storage/test_unsafe_locale.py
+++ b/tests/storage/test_unsafe_locale.py
@@ -14,6 +14,7 @@
from unittest.mock import MagicMock, patch
from synapse.storage.database import make_conn
+from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IncorrectDatabaseSetup
from tests.unittest import HomeserverTestCase
@@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase):
def test_safe_locale(self) -> None:
database = self.hs.get_datastores().databases[0]
+ assert isinstance(database.engine, PostgresEngine)
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
with db_conn.cursor() as txn:
|