diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 3ac4646969..74c6224eb6 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""
# First to acquire this lock, so it should complete
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
- self.assertIsNotNone(lock)
+ assert lock is not None
# Enter the context manager
self.get_success(lock.__aenter__())
@@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase):
# We can now acquire the lock again.
lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
- self.assertIsNotNone(lock3)
+ assert lock3 is not None
self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None))
@@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""Test that we don't time out locks while they're still active"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
- self.assertIsNotNone(lock)
+ assert lock is not None
self.get_success(lock.__aenter__())
@@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""Test that we time out locks if they're not updated for ages"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
- self.assertIsNotNone(lock)
+ assert lock is not None
self.get_success(lock.__aenter__())
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ee599f4336..1bf93e79a7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
+from synapse.types import DeviceListUpdates
from synapse.util import Clock
from tests import unittest
@@ -168,15 +169,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
(as_id, txn_id, json.dumps([e.event_id for e in events])),
)
- def _set_last_txn(self, as_id, txn_id):
- return self.db_pool.runOperation(
- self.engine.convert_param_style(
- "INSERT INTO application_services_state(as_id, last_txn, state) "
- "VALUES(?,?,?)"
- ),
- (as_id, txn_id, ApplicationServiceState.UP.value),
- )
-
def test_get_appservice_state_none(
self,
) -> None:
@@ -267,65 +259,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success(
defer.ensureDeferred(
- self.store.create_appservice_txn(service, events, [], [], {}, {})
+ self.store.create_appservice_txn(
+ service, events, [], [], {}, {}, DeviceListUpdates()
+ )
)
)
self.assertEqual(txn.id, 1)
self.assertEqual(txn.events, events)
self.assertEqual(txn.service, service)
- def test_create_appservice_txn_older_last_txn(
- self,
- ) -> None:
- service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
- self.get_success(self._insert_txn(service.id, 9644, events))
- self.get_success(self._insert_txn(service.id, 9645, events))
- txn = self.get_success(
- self.store.create_appservice_txn(service, events, [], [], {}, {})
- )
- self.assertEqual(txn.id, 9646)
- self.assertEqual(txn.events, events)
- self.assertEqual(txn.service, service)
-
- def test_create_appservice_txn_up_to_date_last_txn(
- self,
- ) -> None:
- service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- self.get_success(self._set_last_txn(service.id, 9643))
- txn = self.get_success(
- self.store.create_appservice_txn(service, events, [], [], {}, {})
- )
- self.assertEqual(txn.id, 9644)
- self.assertEqual(txn.events, events)
- self.assertEqual(txn.service, service)
-
- def test_create_appservice_txn_up_fuzzing(
- self,
- ) -> None:
- service = Mock(id=self.as_list[0]["id"])
- events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
- self.get_success(self._set_last_txn(service.id, 9643))
-
- # dump in rows with higher IDs to make sure the queries aren't wrong.
- self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643))
- self.get_success(self._set_last_txn(self.as_list[2]["id"], 9))
- self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events))
- self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events))
- self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
- self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
-
- txn = self.get_success(
- self.store.create_appservice_txn(service, events, [], [], {}, {})
- )
- self.assertEqual(txn.id, 9644)
- self.assertEqual(txn.events, events)
- self.assertEqual(txn.service, service)
-
def test_complete_appservice_txn_first_txn(
self,
) -> None:
@@ -359,13 +301,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(0, len(res))
- def test_complete_appservice_txn_existing_in_state_table(
+ def test_complete_appservice_txn_updates_last_txn_state(
self,
) -> None:
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn_id = 5
- self.get_success(self._set_last_txn(service.id, 4))
+ self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
self.get_success(self._insert_txn(service.id, txn_id, events))
self.get_success(
self.store.complete_appservice_txn(txn_id=txn_id, service=service)
@@ -416,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 12, other_events))
txn = self.get_success(self.store.get_oldest_unsent_txn(service))
+ assert txn is not None
self.assertEqual(service, txn.service)
self.assertEqual(10, txn.id)
self.assertEqual(events, txn.events)
@@ -476,12 +419,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
value = self.get_success(
self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
)
- self.assertEqual(value, 0)
+ self.assertEqual(value, 1)
value = self.get_success(
self.store.get_type_stream_id_for_appservice(self.service, "presence")
)
- self.assertEqual(value, 0)
+ self.assertEqual(value, 1)
def test_get_type_stream_id_for_appservice_invalid_type(self) -> None:
self.get_failure(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index ce89c96912..b998ad42d9 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -68,6 +68,22 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.wait_for_background_updates()
+ def add_extremity(self, room_id: str, event_id: str) -> None:
+ """
+ Add the given event as an extremity to the room.
+ """
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_insert(
+ table="event_forward_extremities",
+ values={"room_id": room_id, "event_id": event_id},
+ desc="test_add_extremity",
+ )
+ )
+
+ self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate(
+ (room_id,)
+ )
+
def test_soft_failed_extremities_handled_correctly(self):
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -250,7 +266,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
self.requester = create_requester(self.user)
- info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
+ info, _ = self.get_success(
+ self.room_creator.create_room(self.requester, {"visibility": "public"})
+ )
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 21ffc5a909..d1227dd4ac 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add two device updates with sequential `stream_id`s
self.get_success(
- self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"], ["!some:room"]
+ )
)
# Get all device updates ever meant for this remote
@@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
"device_id5",
]
self.get_success(
- self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"], ["!some:room"]
+ )
)
# Get device updates meant for this remote
@@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add some more device updates to ensure it still resumes properly
device_ids = ["device_id6", "device_id7"]
self.get_success(
- self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+ self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"], ["!some:room"]
+ )
)
# Get the next batch of device updates
@@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
self.get_success(
self.store.add_device_change_to_streams(
- "@user_id:test", device_ids, ["somehost"]
+ "@user_id:test", device_ids, ["somehost"], ["!some:room"]
)
)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 395396340b..2d8d1f860f 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -157,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- ctx1 = self.get_success(id_gen.get_next())
- ctx2 = self.get_success(id_gen.get_next())
- ctx3 = self.get_success(id_gen.get_next())
- ctx4 = self.get_success(id_gen.get_next())
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+ ctx4 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__())
@@ -362,8 +362,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Persist two rows at once
- ctx1 = self.get_success(id_gen.get_next())
- ctx2 = self.get_success(id_gen.get_next())
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__())
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 03e9cc7d4a..d8d17ef379 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -119,11 +119,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
return event
def test_redact(self):
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+ msg_event = self.inject_message(self.room1, self.u_alice, "t")
# Check event has not been redacted:
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -141,9 +139,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
# Redact event
reason = "Because I said so"
- self.get_success(
- self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
- )
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -170,14 +166,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
def test_redact_join(self):
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- msg_event = self.get_success(
- self.inject_room_member(
- self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
- )
+ msg_event = self.inject_room_member(
+ self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
)
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -195,9 +187,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
# Redact event
reason = "Because I said so"
- self.get_success(
- self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
- )
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
# Check redaction
@@ -311,11 +301,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def test_redact_censor(self):
"""Test that a redacted event gets censored in the DB after a month"""
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+ msg_event = self.inject_message(self.room1, self.u_alice, "t")
# Check event has not been redacted:
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -333,9 +321,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
# Redact event
reason = "Because I said so"
- self.get_success(
- self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
- )
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
event = self.get_success(self.store.get_event(msg_event.event_id))
@@ -381,25 +367,19 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def test_redact_redaction(self):
"""Tests that we can redact a redaction and can fetch it again."""
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+ msg_event = self.inject_message(self.room1, self.u_alice, "t")
- first_redact_event = self.get_success(
- self.inject_redaction(
- self.room1, msg_event.event_id, self.u_alice, "Redacting message"
- )
+ first_redact_event = self.inject_redaction(
+ self.room1, msg_event.event_id, self.u_alice, "Redacting message"
)
- self.get_success(
- self.inject_redaction(
- self.room1,
- first_redact_event.event_id,
- self.u_alice,
- "Redacting redaction",
- )
+ self.inject_redaction(
+ self.room1,
+ first_redact_event.event_id,
+ self.u_alice,
+ "Redacting redaction",
)
# Now lets jump to the future where we have censored the redaction event
@@ -414,9 +394,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def test_store_redacted_redaction(self):
"""Tests that we can store a redacted redaction."""
- self.get_success(
- self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
- )
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index eaa0d7d749..52e41cdab4 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -110,9 +110,7 @@ class PaginationTestCase(HomeserverTestCase):
def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
"""Make a request to /messages with a filter, returns the chunk of events."""
- from_token = self.get_success(
- self.hs.get_event_sources().get_current_token_for_pagination()
- )
+ from_token = self.hs.get_event_sources().get_current_token_for_pagination()
events, next_key = self.get_success(
self.hs.get_datastores().main.paginate_room_events(
|