summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/databases/main/test_lock.py8
-rw-r--r--tests/storage/test_appservice.py75
-rw-r--r--tests/storage/test_cleanup_extrems.py20
-rw-r--r--tests/storage/test_devices.py14
-rw-r--r--tests/storage/test_id_generators.py12
-rw-r--r--tests/storage/test_redaction.py62
-rw-r--r--tests/storage/test_stream.py4
7 files changed, 69 insertions, 126 deletions
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py

index 3ac4646969..74c6224eb6 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py
@@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase): """ # First to acquire this lock, so it should complete lock = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock) + assert lock is not None # Enter the context manager self.get_success(lock.__aenter__()) @@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase): # We can now acquire the lock again. lock3 = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock3) + assert lock3 is not None self.get_success(lock3.__aenter__()) self.get_success(lock3.__aexit__(None, None, None)) @@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase): """Test that we don't time out locks while they're still active""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock) + assert lock is not None self.get_success(lock.__aenter__()) @@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase): """Test that we time out locks if they're not updated for ages""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock) + assert lock is not None self.get_success(lock.__aenter__()) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ee599f4336..1bf93e79a7 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest @@ -168,15 +169,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (as_id, txn_id, json.dumps([e.event_id for e in events])), ) - def _set_last_txn(self, as_id, txn_id): - return self.db_pool.runOperation( - self.engine.convert_param_style( - "INSERT INTO application_services_state(as_id, last_txn, state) " - "VALUES(?,?,?)" - ), - (as_id, txn_id, ApplicationServiceState.UP.value), - ) - def test_get_appservice_state_none( self, ) -> None: @@ -267,65 +259,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) ) self.assertEqual(txn.id, 1) self.assertEqual(txn.events, events) self.assertEqual(txn.service, service) - def test_create_appservice_txn_older_last_txn( - self, - ) -> None: - service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind - self.get_success(self._insert_txn(service.id, 9644, events)) - self.get_success(self._insert_txn(service.id, 9645, events)) - txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) - ) - self.assertEqual(txn.id, 9646) - self.assertEqual(txn.events, events) - self.assertEqual(txn.service, service) - - def test_create_appservice_txn_up_to_date_last_txn( - self, - ) -> None: - service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) - txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) - ) - self.assertEqual(txn.id, 9644) - self.assertEqual(txn.events, events) - self.assertEqual(txn.service, service) - - def test_create_appservice_txn_up_fuzzing( - self, - ) -> None: - service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) - - # dump in rows with higher IDs to make sure the queries aren't wrong. - self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643)) - self.get_success(self._set_last_txn(self.as_list[2]["id"], 9)) - self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events)) - self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events)) - self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) - - txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) - ) - self.assertEqual(txn.id, 9644) - self.assertEqual(txn.events, events) - self.assertEqual(txn.service, service) - def test_complete_appservice_txn_first_txn( self, ) -> None: @@ -359,13 +301,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): ) self.assertEqual(0, len(res)) - def test_complete_appservice_txn_existing_in_state_table( + def test_complete_appservice_txn_updates_last_txn_state( self, ) -> None: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 5 - self.get_success(self._set_last_txn(service.id, 4)) + self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) self.get_success(self._insert_txn(service.id, txn_id, events)) self.get_success( self.store.complete_appservice_txn(txn_id=txn_id, service=service) @@ -416,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 12, other_events)) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) + assert txn is not None self.assertEqual(service, txn.service) self.assertEqual(10, txn.id) self.assertEqual(events, txn.events) @@ -476,12 +419,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) - self.assertEqual(value, 0) + self.assertEqual(value, 1) value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "presence") ) - self.assertEqual(value, 0) + self.assertEqual(value, 1) def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index ce89c96912..b998ad42d9 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -68,6 +68,22 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): self.wait_for_background_updates() + def add_extremity(self, room_id: str, event_id: str) -> None: + """ + Add the given event as an extremity to the room. + """ + self.get_success( + self.hs.get_datastores().main.db_pool.simple_insert( + table="event_forward_extremities", + values={"room_id": room_id, "event_id": event_id}, + desc="test_add_extremity", + ) + ) + + self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( + (room_id,) + ) + def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of soft failed events. @@ -250,7 +266,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = create_requester(self.user) - info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success( + self.room_creator.create_room(self.requester, {"visibility": "public"}) + ) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.consent.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 21ffc5a909..d1227dd4ac 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py
@@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase): # Add two device updates with sequential `stream_id`s self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get all device updates ever meant for this remote @@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase): "device_id5", ] self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get device updates meant for this remote @@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase): # Add some more device updates to ensure it still resumes properly device_ids = ["device_id6", "device_id7"] self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get the next batch of device updates @@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase): self.get_success( self.store.add_device_change_to_streams( - "@user_id:test", device_ids, ["somehost"] + "@user_id:test", device_ids, ["somehost"], ["!some:room"] ) ) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 395396340b..2d8d1f860f 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py
@@ -157,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx1 = self.get_success(id_gen.get_next()) - ctx2 = self.get_success(id_gen.get_next()) - ctx3 = self.get_success(id_gen.get_next()) - ctx4 = self.get_success(id_gen.get_next()) + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + ctx4 = id_gen.get_next() s1 = self.get_success(ctx1.__aenter__()) s2 = self.get_success(ctx2.__aenter__()) @@ -362,8 +362,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Persist two rows at once - ctx1 = self.get_success(id_gen.get_next()) - ctx2 = self.get_success(id_gen.get_next()) + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() s1 = self.get_success(ctx1.__aenter__()) s2 = self.get_success(ctx2.__aenter__()) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 03e9cc7d4a..d8d17ef379 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py
@@ -119,11 +119,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): return event def test_redact(self): - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + msg_event = self.inject_message(self.room1, self.u_alice, "t") # Check event has not been redacted: event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -141,9 +139,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): # Redact event reason = "Because I said so" - self.get_success( - self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) - ) + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -170,14 +166,10 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) def test_redact_join(self): - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = self.get_success( - self.inject_room_member( - self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"} - ) + msg_event = self.inject_room_member( + self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"} ) event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -195,9 +187,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): # Redact event reason = "Because I said so" - self.get_success( - self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) - ) + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) # Check redaction @@ -311,11 +301,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): def test_redact_censor(self): """Test that a redacted event gets censored in the DB after a month""" - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + msg_event = self.inject_message(self.room1, self.u_alice, "t") # Check event has not been redacted: event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -333,9 +321,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): # Redact event reason = "Because I said so" - self.get_success( - self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) - ) + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -381,25 +367,19 @@ class RedactionTestCase(unittest.HomeserverTestCase): def test_redact_redaction(self): """Tests that we can redact a redaction and can fetch it again.""" - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + msg_event = self.inject_message(self.room1, self.u_alice, "t") - first_redact_event = self.get_success( - self.inject_redaction( - self.room1, msg_event.event_id, self.u_alice, "Redacting message" - ) + first_redact_event = self.inject_redaction( + self.room1, msg_event.event_id, self.u_alice, "Redacting message" ) - self.get_success( - self.inject_redaction( - self.room1, - first_redact_event.event_id, - self.u_alice, - "Redacting redaction", - ) + self.inject_redaction( + self.room1, + first_redact_event.event_id, + self.u_alice, + "Redacting redaction", ) # Now lets jump to the future where we have censored the redaction event @@ -414,9 +394,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def test_store_redacted_redaction(self): """Tests that we can store a redacted redaction.""" - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) builder = self.event_builder_factory.for_room_version( RoomVersions.V1, diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index eaa0d7d749..52e41cdab4 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py
@@ -110,9 +110,7 @@ class PaginationTestCase(HomeserverTestCase): def _filter_messages(self, filter: JsonDict) -> List[EventBase]: """Make a request to /messages with a filter, returns the chunk of events.""" - from_token = self.get_success( - self.hs.get_event_sources().get_current_token_for_pagination() - ) + from_token = self.hs.get_event_sources().get_current_token_for_pagination() events, next_key = self.get_success( self.hs.get_datastores().main.paginate_room_events(