diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 01d257307c..875b0d0a11 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -302,11 +302,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
)
# Check that the expected presence updates were sent
- expected_users = [
+ # We explicitly compare using sets as we expect that calling
+ # module_api.send_local_online_presence_to will create a presence
+ # update that is a duplicate of the specified user's current presence.
+ # These are sent to clients and will be picked up below, thus we use a
+ # set to deduplicate. We're just interested that non-offline updates were
+ # sent out for each user ID.
+ expected_users = {
self.other_user_id,
self.presence_receiving_user_one_id,
self.presence_receiving_user_two_id,
- ]
+ }
+ found_users = set()
calls = (
self.hs.get_federation_transport_client().send_transaction.call_args_list
@@ -326,12 +333,12 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
# EDUs can contain multiple presence updates
for presence_update in edu["content"]["push"]:
# Check for presence updates that contain the user IDs we're after
- expected_users.remove(presence_update["user_id"])
+ found_users.add(presence_update["user_id"])
# Ensure that no offline states are being sent out
self.assertNotEqual(presence_update["presence"], "offline")
- self.assertEqual(len(expected_users), 0)
+ self.assertEqual(found_users, expected_users)
def send_presence_update(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 1ffab709fc..d90a9fec91 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -32,13 +32,19 @@ from synapse.handlers.presence import (
handle_timeout,
handle_update,
)
+from synapse.rest import admin
from synapse.rest.client.v1 import room
from synapse.types import UserID, get_domain_from_id
from tests import unittest
-class PresenceUpdateTestCase(unittest.TestCase):
+class PresenceUpdateTestCase(unittest.HomeserverTestCase):
+ servlets = [admin.register_servlets]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+
def test_offline_to_online(self):
wheel_timer = Mock()
user_id = "@foo:bar"
@@ -292,6 +298,45 @@ class PresenceUpdateTestCase(unittest.TestCase):
any_order=True,
)
+ def test_persisting_presence_updates(self):
+ """Tests that the latest presence state for each user is persisted correctly"""
+ # Create some test users and presence states for them
+ presence_states = []
+ for i in range(5):
+ user_id = self.register_user(f"user_{i}", "password")
+
+ presence_state = UserPresenceState(
+ user_id=user_id,
+ state="online",
+ last_active_ts=1,
+ last_federation_update_ts=1,
+ last_user_sync_ts=1,
+ status_msg="I'm online!",
+ currently_active=True,
+ )
+ presence_states.append(presence_state)
+
+ # Persist these presence updates to the database
+ self.get_success(self.store.update_presence(presence_states))
+
+ # Check that each update is present in the database
+ db_presence_states = self.get_success(
+ self.store.get_all_presence_updates(
+ instance_name="master",
+ last_id=0,
+ current_id=len(presence_states) + 1,
+ limit=len(presence_states),
+ )
+ )
+
+ # Extract presence update user ID and state information into lists of tuples
+ db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
+ presence_states = [(ps.user_id, ps.state) for ps in presence_states]
+
+ # Compare what we put into the storage with what we got out.
+ # They should be identical.
+ self.assertEqual(presence_states, db_presence_states)
+
class PresenceTimeoutTestCase(unittest.TestCase):
def test_idle_timer(self):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 0c89487eaf..f58afbc244 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -89,14 +89,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.event_source = hs.get_event_sources().sources["typing"]
self.datastore = hs.get_datastore()
- retry_timings_res = {
- "destination": "",
- "retry_last_ts": 0,
- "retry_interval": 0,
- "failure_ts": None,
- }
self.datastore.get_destination_retry_timings = Mock(
- return_value=defer.succeed(retry_timings_res)
+ return_value=defer.succeed(None)
)
self.datastore.get_device_updates_by_remote = Mock(
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 742ad14b8c..2c68b9a13c 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -13,6 +13,8 @@
# limitations under the License.
from unittest.mock import Mock
+from twisted.internet import defer
+
from synapse.api.constants import EduTypes
from synapse.events import EventBase
from synapse.federation.units import Transaction
@@ -22,11 +24,13 @@ from synapse.rest.client.v1 import login, presence, room
from synapse.types import create_requester
from tests.events.test_presence_router import send_presence_update, sync_presence
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils.event_injection import inject_member_event
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import HomeserverTestCase, override_config
+from tests.utils import USE_POSTGRES_FOR_TESTS
-class ModuleApiTestCase(FederatingHomeserverTestCase):
+class ModuleApiTestCase(HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -217,97 +221,16 @@ class ModuleApiTestCase(FederatingHomeserverTestCase):
)
self.assertFalse(is_in_public_rooms)
- # The ability to send federation is required by send_local_online_presence_to.
- @override_config({"send_federation": True})
def test_send_local_online_presence_to(self):
- """Tests that send_local_presence_to_users sends local online presence to local users."""
- # Create a user who will send presence updates
- self.presence_receiver_id = self.register_user("presence_receiver", "monkey")
- self.presence_receiver_tok = self.login("presence_receiver", "monkey")
-
- # And another user that will send presence updates out
- self.presence_sender_id = self.register_user("presence_sender", "monkey")
- self.presence_sender_tok = self.login("presence_sender", "monkey")
-
- # Put them in a room together so they will receive each other's presence updates
- room_id = self.helper.create_room_as(
- self.presence_receiver_id,
- tok=self.presence_receiver_tok,
- )
- self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok)
-
- # Presence sender comes online
- send_presence_update(
- self,
- self.presence_sender_id,
- self.presence_sender_tok,
- "online",
- "I'm online!",
- )
-
- # Presence receiver should have received it
- presence_updates, sync_token = sync_presence(self, self.presence_receiver_id)
- self.assertEqual(len(presence_updates), 1)
-
- presence_update = presence_updates[0] # type: UserPresenceState
- self.assertEqual(presence_update.user_id, self.presence_sender_id)
- self.assertEqual(presence_update.state, "online")
-
- # Syncing again should result in no presence updates
- presence_updates, sync_token = sync_presence(
- self, self.presence_receiver_id, sync_token
- )
- self.assertEqual(len(presence_updates), 0)
-
- # Trigger sending local online presence
- self.get_success(
- self.module_api.send_local_online_presence_to(
- [
- self.presence_receiver_id,
- ]
- )
- )
-
- # Presence receiver should have received online presence again
- presence_updates, sync_token = sync_presence(
- self, self.presence_receiver_id, sync_token
- )
- self.assertEqual(len(presence_updates), 1)
-
- presence_update = presence_updates[0] # type: UserPresenceState
- self.assertEqual(presence_update.user_id, self.presence_sender_id)
- self.assertEqual(presence_update.state, "online")
-
- # Presence sender goes offline
- send_presence_update(
- self,
- self.presence_sender_id,
- self.presence_sender_tok,
- "offline",
- "I slink back into the darkness.",
- )
-
- # Trigger sending local online presence
- self.get_success(
- self.module_api.send_local_online_presence_to(
- [
- self.presence_receiver_id,
- ]
- )
- )
-
- # Presence receiver should *not* have received offline state
- presence_updates, sync_token = sync_presence(
- self, self.presence_receiver_id, sync_token
- )
- self.assertEqual(len(presence_updates), 0)
+ # Test sending local online presence to users from the main process
+ _test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
@override_config({"send_federation": True})
def test_send_local_online_presence_to_federation(self):
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates
- self.presence_sender_id = self.register_user("presence_sender", "monkey")
- self.presence_sender_tok = self.login("presence_sender", "monkey")
+ self.presence_sender_id = self.register_user("presence_sender1", "monkey")
+ self.presence_sender_tok = self.login("presence_sender1", "monkey")
# And a room they're a part of
room_id = self.helper.create_room_as(
@@ -374,3 +297,209 @@ class ModuleApiTestCase(FederatingHomeserverTestCase):
found_update = True
self.assertTrue(found_update)
+
+
+class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
+ """For testing ModuleApi functionality in a multi-worker setup"""
+
+ # Testing stream ID replication from the main to worker processes requires postgres
+ # (due to needing `MultiWriterIdGenerator`).
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ presence.register_servlets,
+ ]
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["redis"] = {"enabled": "true"}
+ conf["stream_writers"] = {"presence": ["presence_writer"]}
+ conf["instance_map"] = {
+ "presence_writer": {"host": "testserv", "port": 1001},
+ }
+ return conf
+
+ def prepare(self, reactor, clock, homeserver):
+ self.module_api = homeserver.get_module_api()
+ self.sync_handler = homeserver.get_sync_handler()
+
+ def test_send_local_online_presence_to_workers(self):
+ # Test sending local online presence to users from a worker process
+ _test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
+
+
+def _test_sending_local_online_presence_to_local_user(
+ test_case: HomeserverTestCase, test_with_workers: bool = False
+):
+ """Tests that send_local_presence_to_users sends local online presence to local users.
+
+ This simultaneously tests two different usecases:
+ * Testing that this method works when either called from a worker or the main process.
+ - We test this by calling this method from both a TestCase that runs in monolith mode, and one that
+ runs with a main and generic_worker.
+ * Testing that multiple devices syncing simultaneously will all receive a snapshot of local,
+ online presence - but only once per device.
+
+ Args:
+ test_with_workers: If True, this method will call ModuleApi.send_local_online_presence_to on a
+ worker process. The test users will still sync with the main process. The purpose of testing
+ with a worker is to check whether a Synapse module running on a worker can inform other workers/
+ the main process that they should include additional presence when a user next syncs.
+ """
+ if test_with_workers:
+ # Create a worker process to make module_api calls against
+ worker_hs = test_case.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "presence_writer"}
+ )
+
+ # Create a user who will send presence updates
+ test_case.presence_receiver_id = test_case.register_user(
+ "presence_receiver1", "monkey"
+ )
+ test_case.presence_receiver_tok = test_case.login("presence_receiver1", "monkey")
+
+ # And another user that will send presence updates out
+ test_case.presence_sender_id = test_case.register_user("presence_sender2", "monkey")
+ test_case.presence_sender_tok = test_case.login("presence_sender2", "monkey")
+
+ # Put them in a room together so they will receive each other's presence updates
+ room_id = test_case.helper.create_room_as(
+ test_case.presence_receiver_id,
+ tok=test_case.presence_receiver_tok,
+ )
+ test_case.helper.join(
+ room_id, test_case.presence_sender_id, tok=test_case.presence_sender_tok
+ )
+
+ # Presence sender comes online
+ send_presence_update(
+ test_case,
+ test_case.presence_sender_id,
+ test_case.presence_sender_tok,
+ "online",
+ "I'm online!",
+ )
+
+ # Presence receiver should have received it
+ presence_updates, sync_token = sync_presence(
+ test_case, test_case.presence_receiver_id
+ )
+ test_case.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
+ test_case.assertEqual(presence_update.state, "online")
+
+ if test_with_workers:
+ # Replicate the current sync presence token from the main process to the worker process.
+ # We need to do this so that the worker process knows the current presence stream ID to
+ # insert into the database when we call ModuleApi.send_local_online_presence_to.
+ test_case.replicate()
+
+ # Syncing again should result in no presence updates
+ presence_updates, sync_token = sync_presence(
+ test_case, test_case.presence_receiver_id, sync_token
+ )
+ test_case.assertEqual(len(presence_updates), 0)
+
+ # We do an (initial) sync with a second "device" now, getting a new sync token.
+ # We'll use this in a moment.
+ _, sync_token_second_device = sync_presence(
+ test_case, test_case.presence_receiver_id
+ )
+
+ # Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
+ if test_with_workers:
+ module_api_to_use = worker_hs.get_module_api()
+ else:
+ module_api_to_use = test_case.module_api
+
+ # Trigger sending local online presence. We expect this information
+ # to be saved to the database where all processes can access it.
+ # Note that we're syncing via the master.
+ d = module_api_to_use.send_local_online_presence_to(
+ [
+ test_case.presence_receiver_id,
+ ]
+ )
+ d = defer.ensureDeferred(d)
+
+ if test_with_workers:
+ # In order for the required presence_set_state replication request to occur between the
+ # worker and main process, we need to pump the reactor. Otherwise, the coordinator that
+ # reads the request on the main process won't do so, and the request will time out.
+ while not d.called:
+ test_case.reactor.advance(0.1)
+
+ test_case.get_success(d)
+
+ # The presence receiver should have received online presence again.
+ presence_updates, sync_token = sync_presence(
+ test_case, test_case.presence_receiver_id, sync_token
+ )
+ test_case.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
+ test_case.assertEqual(presence_update.state, "online")
+
+ # We attempt to sync with the second sync token we received above - just to check that
+ # multiple syncing devices will each receive the necessary online presence.
+ presence_updates, sync_token_second_device = sync_presence(
+ test_case, test_case.presence_receiver_id, sync_token_second_device
+ )
+ test_case.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
+ test_case.assertEqual(presence_update.state, "online")
+
+ # However, if we now sync with either "device", we won't receive another burst of online presence
+ # until the API is called again sometime in the future
+ presence_updates, sync_token = sync_presence(
+ test_case, test_case.presence_receiver_id, sync_token
+ )
+
+ # Now we check that we don't receive *offline* updates using ModuleApi.send_local_online_presence_to.
+
+ # Presence sender goes offline
+ send_presence_update(
+ test_case,
+ test_case.presence_sender_id,
+ test_case.presence_sender_tok,
+ "offline",
+ "I slink back into the darkness.",
+ )
+
+ # Presence receiver should have received the updated, offline state
+ presence_updates, sync_token = sync_presence(
+ test_case, test_case.presence_receiver_id, sync_token
+ )
+ test_case.assertEqual(len(presence_updates), 1)
+
+ # Now trigger sending local online presence.
+ d = module_api_to_use.send_local_online_presence_to(
+ [
+ test_case.presence_receiver_id,
+ ]
+ )
+ d = defer.ensureDeferred(d)
+
+ if test_with_workers:
+ # In order for the required presence_set_state replication request to occur between the
+ # worker and main process, we need to pump the reactor. Otherwise, the coordinator that
+ # reads the request on the main process won't do so, and the request will time out.
+ while not d.called:
+ test_case.reactor.advance(0.1)
+
+ test_case.get_success(d)
+
+ # Presence receiver should *not* have received offline state
+ presence_updates, sync_token = sync_presence(
+ test_case, test_case.presence_receiver_id, sync_token
+ )
+ test_case.assertEqual(len(presence_updates), 0)
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index d739eb6b17..5eca5c165d 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -30,7 +30,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks event persisting sharding works"""
# Event persister sharding requires postgres (due to needing
- # `MutliWriterIdGenerator`).
+ # `MultiWriterIdGenerator`).
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index b7f7eae8d0..bea9091d30 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.storage.databases.main.transactions import DestinationRetryTimings
from synapse.util.retryutils import MAX_RETRY_INTERVAL
from tests.unittest import HomeserverTestCase
@@ -36,8 +37,11 @@ class TransactionStoreTestCase(HomeserverTestCase):
d = self.store.get_destination_retry_timings("example.com")
r = self.get_success(d)
- self.assert_dict(
- {"retry_last_ts": 50, "retry_interval": 100, "failure_ts": 1000}, r
+ self.assertEqual(
+ DestinationRetryTimings(
+ retry_last_ts=50, retry_interval=100, failure_ts=1000
+ ),
+ r,
)
def test_initial_set_transactions(self):
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 178ac8a68c..bbbc276697 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -666,18 +666,20 @@ class CachedListDescriptorTestCase(unittest.TestCase):
with LoggingContext("c1") as c1:
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
+
+ # start the lookup off
d1 = obj.list_fn([10, 20], 2)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
r = yield d1
self.assertEqual(current_context(), c1)
- obj.mock.assert_called_once_with([10, 20], 2)
+ obj.mock.assert_called_once_with((10, 20), 2)
self.assertEqual(r, {10: "fish", 20: "chips"})
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = {30: "peas"}
r = yield obj.list_fn([20, 30], 2)
- obj.mock.assert_called_once_with([30], 2)
+ obj.mock.assert_called_once_with((30,), 2)
self.assertEqual(r, {20: "chips", 30: "peas"})
obj.mock.reset_mock()
@@ -692,6 +694,15 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
+ # we should also be able to use a (single-use) iterable, and should
+ # deduplicate the keys
+ obj.mock.reset_mock()
+ obj.mock.return_value = {40: "gravy"}
+ iterable = (x for x in [10, 40, 40])
+ r = yield obj.list_fn(iterable, 2)
+ obj.mock.assert_called_once_with((40,), 2)
+ self.assertEqual(r, {10: "fish", 40: "gravy"})
+
@defer.inlineCallbacks
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
@@ -717,7 +728,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
# cache miss
obj.mock.return_value = {10: "fish", 20: "chips"}
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
- obj.mock.assert_called_once_with([10, 20], 2)
+ obj.mock.assert_called_once_with((10, 20), 2)
self.assertEqual(r1, {10: "fish", 20: "chips"})
obj.mock.reset_mock()
diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
new file mode 100644
index 0000000000..5def1e56c9
--- /dev/null
+++ b/tests/util/test_batching_queue.py
@@ -0,0 +1,169 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util.batching_queue import BatchingQueue
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+
+
+class BatchingQueueTestCase(TestCase):
+ def setUp(self):
+ self.clock, hs_clock = get_clock()
+
+ self._pending_calls = []
+ self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
+
+ async def _process_queue(self, values):
+ d = defer.Deferred()
+ self._pending_calls.append((values, d))
+ return await make_deferred_yieldable(d)
+
+ def test_simple(self):
+ """Tests the basic case of calling `add_to_queue` once and having
+ `_process_queue` return.
+ """
+
+ self.assertFalse(self._pending_calls)
+
+ queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
+
+ # The queue should wait a reactor tick before calling the processing
+ # function.
+ self.assertFalse(self._pending_calls)
+ self.assertFalse(queue_d.called)
+
+ # We should see a call to `_process_queue` after a reactor tick.
+ self.clock.pump([0])
+
+ self.assertEqual(len(self._pending_calls), 1)
+ self.assertEqual(self._pending_calls[0][0], ["foo"])
+ self.assertFalse(queue_d.called)
+
+ # Return value of the `_process_queue` should be propagated back.
+ self._pending_calls.pop()[1].callback("bar")
+
+ self.assertEqual(self.successResultOf(queue_d), "bar")
+
+ def test_batching(self):
+ """Test that multiple calls at the same time get batched up into one
+ call to `_process_queue`.
+ """
+
+ self.assertFalse(self._pending_calls)
+
+ queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
+ queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
+
+ self.clock.pump([0])
+
+ # We should see only *one* call to `_process_queue`
+ self.assertEqual(len(self._pending_calls), 1)
+ self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
+ self.assertFalse(queue_d1.called)
+ self.assertFalse(queue_d2.called)
+
+ # Return value of the `_process_queue` should be propagated back to both.
+ self._pending_calls.pop()[1].callback("bar")
+
+ self.assertEqual(self.successResultOf(queue_d1), "bar")
+ self.assertEqual(self.successResultOf(queue_d2), "bar")
+
+ def test_queuing(self):
+ """Test that we queue up requests while a `_process_queue` is being
+ called.
+ """
+
+ self.assertFalse(self._pending_calls)
+
+ queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
+ self.clock.pump([0])
+
+ queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
+
+ # We should see only *one* call to `_process_queue`
+ self.assertEqual(len(self._pending_calls), 1)
+ self.assertEqual(self._pending_calls[0][0], ["foo1"])
+ self.assertFalse(queue_d1.called)
+ self.assertFalse(queue_d2.called)
+
+ # Return value of the `_process_queue` should be propagated back to the
+ # first.
+ self._pending_calls.pop()[1].callback("bar1")
+
+ self.assertEqual(self.successResultOf(queue_d1), "bar1")
+ self.assertFalse(queue_d2.called)
+
+ # We should now see a second call to `_process_queue`
+ self.clock.pump([0])
+ self.assertEqual(len(self._pending_calls), 1)
+ self.assertEqual(self._pending_calls[0][0], ["foo2"])
+ self.assertFalse(queue_d2.called)
+
+ # Return value of the `_process_queue` should be propagated back to the
+ # second.
+ self._pending_calls.pop()[1].callback("bar2")
+
+ self.assertEqual(self.successResultOf(queue_d2), "bar2")
+
+ def test_different_keys(self):
+ """Test that calls to different keys get processed in parallel."""
+
+ self.assertFalse(self._pending_calls)
+
+ queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1))
+ self.clock.pump([0])
+ queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2))
+ self.clock.pump([0])
+
+ # We queue up another item with key=2 to check that we will keep taking
+ # things off the queue.
+ queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3", key=2))
+
+ # We should see two calls to `_process_queue`
+ self.assertEqual(len(self._pending_calls), 2)
+ self.assertEqual(self._pending_calls[0][0], ["foo1"])
+ self.assertEqual(self._pending_calls[1][0], ["foo2"])
+ self.assertFalse(queue_d1.called)
+ self.assertFalse(queue_d2.called)
+ self.assertFalse(queue_d3.called)
+
+ # Return value of the `_process_queue` should be propagated back to the
+ # first.
+ self._pending_calls.pop(0)[1].callback("bar1")
+
+ self.assertEqual(self.successResultOf(queue_d1), "bar1")
+ self.assertFalse(queue_d2.called)
+ self.assertFalse(queue_d3.called)
+
+ # Return value of the `_process_queue` should be propagated back to the
+ # second.
+ self._pending_calls.pop()[1].callback("bar2")
+
+ self.assertEqual(self.successResultOf(queue_d2), "bar2")
+ self.assertFalse(queue_d3.called)
+
+ # We should now see a call `_pending_calls` for `foo3`
+ self.clock.pump([0])
+ self.assertEqual(len(self._pending_calls), 1)
+ self.assertEqual(self._pending_calls[0][0], ["foo3"])
+ self.assertFalse(queue_d3.called)
+
+ # Return value of the `_process_queue` should be propagated back to the
+ # third deferred.
+ self._pending_calls.pop()[1].callback("bar4")
+
+ self.assertEqual(self.successResultOf(queue_d3), "bar4")
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 1bd0b45d94..e712eb42ea 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List
+from typing import Dict, Iterable, List, Sequence
from synapse.util.iterutils import chunk_seq, sorted_topologically
@@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase):
)
def test_empty_input(self):
- parts = chunk_seq([], 5)
+ parts = chunk_seq([], 5) # type: Iterable[Sequence]
self.assertEqual(
list(parts),
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index df3e27779f..377904e72e 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -59,7 +59,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEquals(cache.pop("key"), None)
def test_del_multi(self):
- cache = LruCache(4, keylen=2, cache_type=TreeCache)
+ cache = LruCache(4, cache_type=TreeCache)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom"
@@ -165,7 +165,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m2 = Mock()
m3 = Mock()
m4 = Mock()
- cache = LruCache(4, keylen=2, cache_type=TreeCache)
+ cache = LruCache(4, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9b2be83a43..9e1bebdc83 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -51,10 +51,12 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
+ self.pump()
+
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
- self.assertEqual(new_timings["failure_ts"], failure_ts)
- self.assertEqual(new_timings["retry_last_ts"], failure_ts)
- self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
+ self.assertEqual(new_timings.failure_ts, failure_ts)
+ self.assertEqual(new_timings.retry_last_ts, failure_ts)
+ self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
# now if we try again we should get a failure
self.get_failure(
@@ -77,14 +79,16 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
+ self.pump()
+
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
- self.assertEqual(new_timings["failure_ts"], failure_ts)
- self.assertEqual(new_timings["retry_last_ts"], retry_ts)
+ self.assertEqual(new_timings.failure_ts, failure_ts)
+ self.assertEqual(new_timings.retry_last_ts, retry_ts)
self.assertGreaterEqual(
- new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
+ new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
)
self.assertLessEqual(
- new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
+ new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
)
#
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 3b077af27e..6066372053 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from .. import unittest
@@ -64,12 +64,14 @@ class TreeCacheTestCase(unittest.TestCase):
cache[("a", "b")] = "AB"
cache[("b", "a")] = "BA"
self.assertEquals(cache.get(("a", "a")), "AA")
- cache.pop(("a",))
+ popped = cache.pop(("a",))
self.assertEquals(cache.get(("a", "a")), None)
self.assertEquals(cache.get(("a", "b")), None)
self.assertEquals(cache.get(("b", "a")), "BA")
self.assertEquals(len(cache), 1)
+ self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
+
def test_clear(self):
cache = TreeCache()
cache[("a",)] = "A"
|