diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2e6e7abf1f..8ff1460c0d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -23,6 +23,7 @@ from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
+from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
@@ -33,7 +34,6 @@ from synapse.crypto.keyring import (
)
from synapse.logging.context import (
LoggingContext,
- PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
@@ -41,6 +41,7 @@ from synapse.storage.keys import FetchKeyResult
from tests import unittest
from tests.test_utils import make_awaitable
+from tests.unittest import logcontext_clean
class MockPerspectiveServer:
@@ -67,55 +68,42 @@ class MockPerspectiveServer:
signedjson.sign.sign_json(res, self.server_name, self.key)
+@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- self.mock_perspective_server = MockPerspectiveServer()
- self.http_client = Mock()
-
- config = self.default_config()
- config["trusted_key_servers"] = [
- {
- "server_name": self.mock_perspective_server.server_name,
- "verify_keys": self.mock_perspective_server.get_verify_keys(),
- }
- ]
-
- return self.setup_test_homeserver(
- handlers=None, http_client=self.http_client, config=config
- )
-
- def check_context(self, _, expected):
+ def check_context(self, val, expected):
self.assertEquals(getattr(current_context(), "request", None), expected)
+ return val
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- key1 = signedjson.key.generate_signing_key(1)
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock()
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
- kr = keyring.Keyring(self.hs)
+ # a signed object that we are going to try to validate
+ key1 = signedjson.key.generate_signing_key(1)
json1 = {}
signedjson.sign.sign_json(json1, "server10", key1)
- persp_resp = {
- "server_keys": [
- self.mock_perspective_server.get_signed_key(
- "server10", signedjson.key.get_verify_key(key1)
- )
- ]
- }
- persp_deferred = defer.Deferred()
+ # start off a first set of lookups. We make the mock fetcher block until this
+ # deferred completes.
+ first_lookup_deferred = Deferred()
+
+ async def first_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_11")
+ self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
- async def get_perspectives(**kwargs):
- self.assertEquals(current_context().request, "11")
- with PreserveLoggingContext():
- await persp_deferred
- return persp_resp
+ await make_deferred_yieldable(first_lookup_deferred)
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
- self.http_client.post_json.side_effect = get_perspectives
+ mock_fetcher.get_keys.side_effect = first_lookup_fetch
- # start off a first set of lookups
- @defer.inlineCallbacks
- def first_lookup():
- with LoggingContext("11") as context_11:
- context_11.request = "11"
+ async def first_lookup():
+ with LoggingContext("context_11") as context_11:
+ context_11.request = "context_11"
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
@@ -124,7 +112,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
- yield res_deferreds[1]
+ await res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
@@ -132,45 +120,51 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds[0])
+ await make_deferred_yieldable(res_deferreds[0])
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
+ d0 = ensureDeferred(first_lookup())
- d0 = first_lookup()
-
- # wait a tick for it to send the request to the perspectives server
- # (it first tries the datastore)
- self.pump()
- self.http_client.post_json.assert_called_once()
+ mock_fetcher.get_keys.assert_called_once()
# a second request for a server with outstanding requests
# should block rather than start a second call
- @defer.inlineCallbacks
- def second_lookup():
- with LoggingContext("12") as context_12:
- context_12.request = "12"
- self.http_client.post_json.reset_mock()
- self.http_client.post_json.return_value = defer.Deferred()
+
+ async def second_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_12")
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
+
+ mock_fetcher.get_keys.reset_mock()
+ mock_fetcher.get_keys.side_effect = second_lookup_fetch
+ second_lookup_state = [0]
+
+ async def second_lookup():
+ with LoggingContext("context_12") as context_12:
+ context_12.request = "context_12"
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 1
+ await make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 2
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
-
- d2 = second_lookup()
+ d2 = ensureDeferred(second_lookup())
self.pump()
- self.http_client.post_json.assert_not_called()
+ # the second request should be pending, but the fetcher should not yet have been
+ # called
+ self.assertEqual(second_lookup_state[0], 1)
+ mock_fetcher.get_keys.assert_not_called()
# complete the first request
- persp_deferred.callback(persp_resp)
+ first_lookup_deferred.callback(None)
+
+ # and now both verifications should succeed.
self.get_success(d0)
self.get_success(d2)
@@ -317,6 +311,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher2.get_keys.assert_called_once()
+@logcontext_clean
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 3d880c499d..1471cc1a28 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -77,11 +77,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -110,11 +108,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -150,11 +146,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(None)
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
@@ -208,11 +202,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -240,11 +232,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
- )
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
new file mode 100644
index 0000000000..1a3ccb263d
--- /dev/null
+++ b/tests/federation/test_federation_catch_up.py
@@ -0,0 +1,422 @@
+from typing import List, Tuple
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.federation.sender import PerDestinationQueue, TransactionManager
+from synapse.federation.units import Edu
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.test_utils import event_injection, make_awaitable
+from tests.unittest import FederatingHomeserverTestCase, override_config
+
+
+class FederationCatchUpTestCases(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def prepare(self, reactor, clock, hs):
+ # stub out get_current_hosts_in_room
+ state_handler = hs.get_state_handler()
+
+ # This mock is crucial for destination_rooms to be populated.
+ state_handler.get_current_hosts_in_room = Mock(
+ return_value=make_awaitable(["test", "host2"])
+ )
+
+ # whenever send_transaction is called, record the pdu data
+ self.pdus = []
+ self.failed_pdus = []
+ self.is_online = True
+ self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.record_transaction
+ )
+
+ async def record_transaction(self, txn, json_cb):
+ if self.is_online:
+ data = json_cb()
+ self.pdus.extend(data["pdus"])
+ return {}
+ else:
+ data = json_cb()
+ self.failed_pdus.extend(data["pdus"])
+ raise IOError("Failed to connect because this is a test!")
+
+ def get_destination_room(self, room: str, destination: str = "host2") -> dict:
+ """
+ Gets the destination_rooms entry for a (destination, room_id) pair.
+
+ Args:
+ room: room ID
+ destination: what destination, default is "host2"
+
+ Returns:
+ Dictionary of { event_id: str, stream_ordering: int }
+ """
+ event_id, stream_ordering = self.get_success(
+ self.hs.get_datastore().db_pool.execute(
+ "test:get_destination_rooms",
+ None,
+ """
+ SELECT event_id, stream_ordering
+ FROM destination_rooms dr
+ JOIN events USING (stream_ordering)
+ WHERE dr.destination = ? AND dr.room_id = ?
+ """,
+ destination,
+ room,
+ )
+ )[0]
+ return {"event_id": event_id, "stream_ordering": stream_ordering}
+
+ @override_config({"send_federation": True})
+ def test_catch_up_destination_rooms_tracking(self):
+ """
+ Tests that we populate the `destination_rooms` table as needed.
+ """
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room = self.helper.create_room_as("u1", tok=u1_token)
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+ )
+
+ event_id_1 = self.helper.send(room, "wombats!", tok=u1_token)["event_id"]
+
+ row_1 = self.get_destination_room(room)
+
+ event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+ row_2 = self.get_destination_room(room)
+
+ # check: events correctly registered in order
+ self.assertEqual(row_1["event_id"], event_id_1)
+ self.assertEqual(row_2["event_id"], event_id_2)
+ self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
+
+ @override_config({"send_federation": True})
+ def test_catch_up_last_successful_stream_ordering_tracking(self):
+ """
+ Tests that we populate the `destination_rooms` table as needed.
+ """
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room = self.helper.create_room_as("u1", tok=u1_token)
+
+ # take the remote offline
+ self.is_online = False
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+ )
+
+ self.helper.send(room, "wombats!", tok=u1_token)
+ self.pump()
+
+ lsso_1 = self.get_success(
+ self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+ "host2"
+ )
+ )
+
+ self.assertIsNone(
+ lsso_1,
+ "There should be no last successful stream ordering for an always-offline destination",
+ )
+
+ # bring the remote online
+ self.is_online = True
+
+ event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+ lsso_2 = self.get_success(
+ self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+ "host2"
+ )
+ )
+ row_2 = self.get_destination_room(room)
+
+ self.assertEqual(
+ self.pdus[0]["content"]["body"],
+ "rabbits!",
+ "Test fault: didn't receive the right PDU",
+ )
+ self.assertEqual(
+ row_2["event_id"],
+ event_id_2,
+ "Test fault: destination_rooms not updated correctly",
+ )
+ self.assertEqual(
+ lsso_2,
+ row_2["stream_ordering"],
+ "Send succeeded but not marked as last_successful_stream_ordering",
+ )
+
+ @override_config({"send_federation": True}) # critical to federate
+ def test_catch_up_from_blank_state(self):
+ """
+ Runs an overall test of federation catch-up from scratch.
+ Further tests will focus on more narrow aspects and edge-cases, but I
+ hope to provide an overall view with this test.
+ """
+ # bring the other server online
+ self.is_online = True
+
+ # let's make some events for the other server to receive
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+ room_2 = self.helper.create_room_as("u1", tok=u1_token)
+
+ # also critical to federate
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join")
+ )
+
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "wombat"}, tok=u1_token
+ )
+
+ # check: PDU received for topic event
+ self.assertEqual(len(self.pdus), 1)
+ self.assertEqual(self.pdus[0]["type"], "m.room.topic")
+
+ # take the remote offline
+ self.is_online = False
+
+ # send another event
+ self.helper.send(room_1, "hi user!", tok=u1_token)
+
+ # check: things didn't go well since the remote is down
+ self.assertEqual(len(self.failed_pdus), 1)
+ self.assertEqual(self.failed_pdus[0]["content"]["body"], "hi user!")
+
+ # let's delete the federation transmission queue
+ # (this pretends we are starting up fresh.)
+ self.assertFalse(
+ self.hs.get_federation_sender()
+ ._per_destination_queues["host2"]
+ .transmission_loop_running
+ )
+ del self.hs.get_federation_sender()._per_destination_queues["host2"]
+
+ # let's also clear any backoffs
+ self.get_success(
+ self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0)
+ )
+
+ # bring the remote online and clear the received pdu list
+ self.is_online = True
+ self.pdus = []
+
+ # now we need to initiate a federation transaction somehow…
+ # to do that, let's send another event (because it's simple to do)
+ # (do it to another room otherwise the catch-up logic decides it doesn't
+ # need to catch up room_1 — something I overlooked when first writing
+ # this test)
+ self.helper.send(room_2, "wombats!", tok=u1_token)
+
+ # we should now have received both PDUs
+ self.assertEqual(len(self.pdus), 2)
+ self.assertEqual(self.pdus[0]["content"]["body"], "hi user!")
+ self.assertEqual(self.pdus[1]["content"]["body"], "wombats!")
+
+ def make_fake_destination_queue(
+ self, destination: str = "host2"
+ ) -> Tuple[PerDestinationQueue, List[EventBase]]:
+ """
+ Makes a fake per-destination queue.
+ """
+ transaction_manager = TransactionManager(self.hs)
+ per_dest_queue = PerDestinationQueue(self.hs, transaction_manager, destination)
+ results_list = []
+
+ async def fake_send(
+ destination_tm: str,
+ pending_pdus: List[EventBase],
+ _pending_edus: List[Edu],
+ ) -> bool:
+ assert destination == destination_tm
+ results_list.extend(pending_pdus)
+ return True # success!
+
+ transaction_manager.send_new_transaction = fake_send
+
+ return per_dest_queue, results_list
+
+ @override_config({"send_federation": True})
+ def test_catch_up_loop(self):
+ """
+ Tests the behaviour of _catch_up_transmission_loop.
+ """
+
+ # ARRANGE:
+ # - a local user (u1)
+ # - 3 rooms which u1 is joined to (and remote user @user:host2 is
+ # joined to)
+ # - some events (1 to 5) in those rooms
+ # we have 'already sent' events 1 and 2 to host2
+ per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+ room_2 = self.helper.create_room_as("u1", tok=u1_token)
+ room_3 = self.helper.create_room_as("u1", tok=u1_token)
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_3, "@user:host2", "join")
+ )
+
+ # create some events
+ self.helper.send(room_1, "you hear me!!", tok=u1_token)
+ event_id_2 = self.helper.send(room_2, "wombats!", tok=u1_token)["event_id"]
+ self.helper.send(room_3, "Matrix!", tok=u1_token)
+ event_id_4 = self.helper.send(room_2, "rabbits!", tok=u1_token)["event_id"]
+ event_id_5 = self.helper.send(room_3, "Synapse!", tok=u1_token)["event_id"]
+
+ # destination_rooms should already be populated, but let us pretend that we already
+ # sent (successfully) up to and including event id 2
+ event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2))
+
+ # also fetch event 5 so we know its last_successful_stream_ordering later
+ event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5))
+
+ self.get_success(
+ self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+ "host2", event_2.internal_metadata.stream_ordering
+ )
+ )
+
+ # ACT
+ self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+ # ASSERT, noticing in particular:
+ # - event 3 not sent out, because event 5 replaces it
+ # - order is least recent first, so event 5 comes after event 4
+ # - catch-up is completed
+ self.assertEqual(len(sent_pdus), 2)
+ self.assertEqual(sent_pdus[0].event_id, event_id_4)
+ self.assertEqual(sent_pdus[1].event_id, event_id_5)
+ self.assertFalse(per_dest_queue._catching_up)
+ self.assertEqual(
+ per_dest_queue._last_successful_stream_ordering,
+ event_5.internal_metadata.stream_ordering,
+ )
+
+ @override_config({"send_federation": True})
+ def test_catch_up_on_synapse_startup(self):
+ """
+ Tests the behaviour of get_catch_up_outstanding_destinations and
+ _wake_destinations_needing_catchup.
+ """
+
+ # list of sorted server names (note that there are more servers than the batch
+ # size used in get_catch_up_outstanding_destinations).
+ server_names = ["server%02d" % number for number in range(42)] + ["zzzerver"]
+
+ # ARRANGE:
+ # - a local user (u1)
+ # - a room which u1 is joined to (and remote users @user:serverXX are
+ # joined to)
+
+ # mark the remotes as online
+ self.is_online = True
+
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_id = self.helper.create_room_as("u1", tok=u1_token)
+
+ for server_name in server_names:
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room_id, "@user:%s" % server_name, "join"
+ )
+ )
+
+ # create an event
+ self.helper.send(room_id, "deary me!", tok=u1_token)
+
+ # ASSERT:
+ # - All servers are up to date so none should have outstanding catch-up
+ outstanding_when_successful = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+ )
+ self.assertEqual(outstanding_when_successful, [])
+
+ # ACT:
+ # - Make the remote servers unreachable
+ self.is_online = False
+
+ # - Mark zzzerver as being backed-off from
+ now = self.clock.time_msec()
+ self.get_success(
+ self.hs.get_datastore().set_destination_retry_timings(
+ "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day
+ )
+ )
+
+ # - Send an event
+ self.helper.send(room_id, "can anyone hear me?", tok=u1_token)
+
+ # ASSERT (get_catch_up_outstanding_destinations):
+ # - all remotes are outstanding
+ # - they are returned in batches of 25, in order
+ outstanding_1 = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+ )
+
+ self.assertEqual(len(outstanding_1), 25)
+ self.assertEqual(outstanding_1, server_names[0:25])
+
+ outstanding_2 = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(
+ outstanding_1[-1]
+ )
+ )
+ self.assertNotIn("zzzerver", outstanding_2)
+ self.assertEqual(len(outstanding_2), 17)
+ self.assertEqual(outstanding_2, server_names[25:-1])
+
+ # ACT: call _wake_destinations_needing_catchup
+
+ # patch wake_destination to just count the destinations instead
+ woken = []
+
+ def wake_destination_track(destination):
+ woken.append(destination)
+
+ self.hs.get_federation_sender().wake_destination = wake_destination_track
+
+ # cancel the pre-existing timer for _wake_destinations_needing_catchup
+ # this is because we are calling it manually rather than waiting for it
+ # to be called automatically
+ self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
+
+ self.get_success(
+ self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
+ )
+
+ # ASSERT (_wake_destinations_needing_catchup):
+ # - all remotes are woken up, save for zzzerver
+ self.assertNotIn("zzzerver", woken)
+ # - all destinations are woken exactly once; they appear once in woken.
+ self.assertCountEqual(woken, server_names[:-1])
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 5f512ff8bf..917762e6b6 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -34,7 +34,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call.
- mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+ mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c7efd3822d..97877c2e42 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -143,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.large_number_of_users)
+ return_value=make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
@@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.large_number_of_users)
+ return_value=make_awaitable(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -169,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -179,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -189,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -200,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
)
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+ return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -216,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.small_number_of_users)
+ return_value=make_awaitable(self.small_number_of_users)
)
# Ensure does not raise exception
yield defer.ensureDeferred(
@@ -226,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.small_number_of_users)
+ return_value=make_awaitable(self.small_number_of_users)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 6aa322bf3a..969d44c787 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -35,6 +35,17 @@ class DeviceTestCase(unittest.HomeserverTestCase):
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
+ def test_device_is_created_with_invalid_name(self):
+ self.get_failure(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="foo",
+ initial_device_display_name="a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
+ ),
+ synapse.api.errors.SynapseError,
+ )
+
def test_device_is_created_if_doesnt_exist(self):
res = self.get_success(
self.handler.check_device_registered(
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 210ddcbb88..366dcfb670 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -30,7 +30,7 @@ from tests import unittest, utils
class E2eKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 3362050ce0..7adde9b9de 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -47,7 +47,7 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 89ec5fcb31..d5087e58be 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -21,7 +21,6 @@ from mock import Mock, patch
import attr
import pymacaroons
-from twisted.internet import defer
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider):
async def map_user_attributes(self, userinfo, token):
return {"localpart": userinfo["username"], "display_name": None}
+ # Do not include get_extra_attributes to test backwards compatibility paths.
+
+
+class TestMappingProviderExtra(TestMappingProvider):
+ async def get_extra_attributes(self, userinfo, token):
+ return {"phone": userinfo["phone"]}
+
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, this mimics part of its behaviour
@@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
config = self.default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = config.get("oidc_config", {})
+ oidc_config = {}
oidc_config["enabled"] = True
oidc_config["client_id"] = CLIENT_ID
oidc_config["client_secret"] = CLIENT_SECRET
@@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config["user_mapping_provider"] = {
"module": __name__ + ".TestMappingProvider",
}
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
hs = self.setup_test_homeserver(
@@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
- @defer.inlineCallbacks
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
- metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ metadata = self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
@@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_metadata())
+ self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
- @defer.inlineCallbacks
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
- yield defer.ensureDeferred(self.handler.load_metadata())
+ self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
- @defer.inlineCallbacks
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
- jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ jwks = self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_jwks())
+ self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
- with self.assertRaises(RuntimeError):
- yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
self.handler._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
- jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ jwks = self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@@ -299,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# This should not throw
self.handler._validate_metadata()
- @defer.inlineCallbacks
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"])
- url = yield defer.ensureDeferred(
+ url = self.get_success(
self.handler.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
@@ -343,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect")
- @defer.inlineCallbacks
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "")
request.args[b"error_description"] = [b"some description"]
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description")
- @defer.inlineCallbacks
def test_callback(self):
"""Code callback works and display errors if something went wrong.
@@ -377,7 +380,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "foo",
"preferred_username": "bar",
}
- user_id = UserID("foo", "domain.org")
+ user_id = "@foo:domain.org"
self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
@@ -394,13 +397,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
- session = self.handler._generate_oidc_session_token(
+ request.getCookie.return_value = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
- request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
@@ -410,10 +412,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
request.getClientIP.return_value = ip_address
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url,
+ user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -427,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._map_userinfo_to_user = simple_async_mock(
raises=MappingException()
)
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error")
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
# Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception())
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
self.handler._auth_handler.complete_sso_login.reset_mock()
@@ -444,10 +446,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url,
+ user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
@@ -459,17 +461,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
self.handler._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
- @defer.inlineCallbacks
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
self.handler._render_error = Mock(return_value=None)
@@ -478,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Missing cookie
request.args = {}
request.getCookie.return_value = None
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("missing_session", "No session cookie found")
# Missing session parameter
request.args = {}
request.getCookie.return_value = "session"
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request", "State parameter is missing")
# Invalid cookie
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = "session"
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_session")
# Mismatching session
@@ -504,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args = {}
request.args[b"state"] = [b"mismatching state"]
request.getCookie.return_value = session
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mismatching_session")
# Valid session
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = session
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
- @defer.inlineCallbacks
def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
@@ -524,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
- ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ ret = self.get_success(self.handler._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
@@ -546,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "foo", "error_description": "bar"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "foo")
- self.assertEqual(exc.exception.error_description, "bar")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "foo")
+ self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
self.http_client.request = simple_async_mock(
@@ -557,9 +556,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "server_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
self.http_client.request = simple_async_mock(
@@ -569,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "internal_server_error"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "internal_server_error")
+
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "server_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
self.http_client.request = simple_async_mock(
@@ -587,9 +584,62 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "some_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "some_error")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestMappingProviderExtra"
+ }
+ }
+ }
+ )
+ def test_extra_attributes(self):
+ """
+ Login while using a mapping provider that implements get_extra_attributes.
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "phone": "1234567",
+ }
+ user_id = "@foo:domain.org"
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
+
+ state = "state"
+ client_redirect_url = "http://client/redirect"
+ request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ state=state,
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
+ )
+
+ request.args = {}
+ request.args[b"code"] = [b"code"]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
+ request.getClientIP.return_value = "10.0.0.1"
+
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url, {"phone": "1234567"},
+ )
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
@@ -617,3 +667,38 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
self.assertEqual(mxid, "@test_user_2:test")
+
+ # Test if the mxid is already taken
+ store = self.hs.get_datastore()
+ user3 = UserID.from_string("@test_user_3:test")
+ self.get_success(
+ store.register_user(user_id=user3.to_string(), password_hash=None)
+ )
+ userinfo = {"sub": "test3", "username": "test_user_3"}
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+
+ @override_config({"oidc_config": {"allow_existing_users": True}})
+ def test_map_userinfo_to_existing_user(self):
+ """Existing users can log in with OpenID Connect when allow_existing_users is True."""
+ store = self.hs.get_datastore()
+ user4 = UserID.from_string("@test_user_4:test")
+ self.get_success(
+ store.register_user(user_id=user4.to_string(), password_hash=None)
+ )
+ userinfo = {
+ "sub": "test4",
+ "username": "test_user_4",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user_4:test")
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index eddf5e2498..cb7c0ed51a 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -100,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1)
+ return_value=make_awaitable(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@@ -108,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.lots_of_users)
+ return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -116,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -126,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.lots_of_users)
+ return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index a609f148c0..312c0a0d41 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -54,7 +54,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms_2",
+ "update_name": "populate_stats_process_rooms",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
@@ -66,7 +66,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms_2",
+ "depends_on": "populate_stats_process_rooms",
},
)
)
@@ -219,10 +219,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.get_success(
self.store.db_pool.simple_insert(
"background_updates",
- {
- "update_name": "populate_stats_process_rooms_2",
- "progress_json": "{}",
- },
+ {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
)
)
self.get_success(
@@ -231,7 +228,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_cleanup",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms_2",
+ "depends_on": "populate_stats_process_rooms",
},
)
)
@@ -728,7 +725,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms_2",
+ "update_name": "populate_stats_process_rooms",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
@@ -740,7 +737,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms_2",
+ "depends_on": "populate_stats_process_rooms",
},
)
)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7bf15c4ba9..3fec09ea8a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -73,6 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"delivered_txn",
"get_received_txn_response",
"set_received_txn_response",
+ "get_destination_last_successful_stream_ordering",
"get_destination_retry_timings",
"get_devices_by_remote",
"maybe_store_room_on_invite",
@@ -80,6 +81,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_user_directory_stream_pos",
"get_current_state_deltas",
"get_device_updates_by_remote",
+ "get_room_max_stream_ordering",
]
)
@@ -116,10 +118,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
+ self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
(0, [])
)
+ self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable(
+ None
+ )
+
def get_received_txn_response(*args):
return defer.succeed(None)
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 5d41443293..3e5a856584 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -145,7 +145,7 @@ class TestServerTLSConnectionFactory:
self._cert_file = create_test_cert_file(sanlist)
def serverConnectionForTLS(self, tlsProtocol):
- ctx = SSL.Context(SSL.TLSv1_METHOD)
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_certificate_file(self._cert_file)
ctx.use_privatekey_file(get_test_key_file())
return Connection(ctx, None)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 5604af3795..212484a7fe 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -318,14 +318,14 @@ class FederationClientTests(HomeserverTestCase):
r = self.successResultOf(d)
self.assertEqual(r.code, 200)
- def test_client_headers_no_body(self):
+ @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
+ def test_timeout_reading_body(self, method_name: str):
"""
If the HTTP request is connected, but gets no response before being
- timed out, it'll give a ResponseNeverReceived.
+ timed out, it'll give a RequestSendFailed with can_retry.
"""
- d = defer.ensureDeferred(
- self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
- )
+ method = getattr(self.cl, method_name)
+ d = defer.ensureDeferred(method("testserv:8008", "foo/bar", timeout=10000))
self.pump()
@@ -349,7 +349,9 @@ class FederationClientTests(HomeserverTestCase):
self.reactor.advance(10.5)
f = self.failureResultOf(d)
- self.assertIsInstance(f.value, TimeoutError)
+ self.assertIsInstance(f.value, RequestSendFailed)
+ self.assertTrue(f.value.can_retry)
+ self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
def test_client_requires_trailing_slashes(self):
"""
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
new file mode 100644
index 0000000000..a1cf0862d4
--- /dev/null
+++ b/tests/http/test_simple_client.py
@@ -0,0 +1,180 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from netaddr import IPSet
+
+from twisted.internet import defer
+from twisted.internet.error import DNSLookupError
+
+from synapse.http import RequestTimedOutError
+from synapse.http.client import SimpleHttpClient
+from synapse.server import HomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SimpleHttpClientTests(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs: "HomeServer"):
+ # Add a DNS entry for a test server
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ self.cl = hs.get_simple_http_client()
+
+ def test_dns_error(self):
+ """
+ If the DNS lookup returns an error, it will bubble up.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv2:8008/foo/bar"))
+ self.pump()
+
+ f = self.failureResultOf(d)
+ self.assertIsInstance(f.value, DNSLookupError)
+
+ def test_client_connection_refused(self):
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+ e = Exception("go away")
+ factory.clientConnectionFailed(None, e)
+ self.pump(0.5)
+
+ f = self.failureResultOf(d)
+
+ self.assertIs(f.value, e)
+
+ def test_client_never_connect(self):
+ """
+ If the HTTP request is not connected and is timed out, it'll give a
+ ConnectingCancelledError or TimeoutError.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], "1.2.3.4")
+ self.assertEqual(clients[0][1], 8008)
+
+ # Deferred is still without a result
+ self.assertNoResult(d)
+
+ # Push by enough to time it out
+ self.reactor.advance(120)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestTimedOutError)
+
+ def test_client_connect_no_response(self):
+ """
+ If the HTTP request is connected, but gets no response before being
+ timed out, it'll give a ResponseNeverReceived.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], "1.2.3.4")
+ self.assertEqual(clients[0][1], 8008)
+
+ conn = Mock()
+ client = clients[0][2].buildProtocol(None)
+ client.makeConnection(conn)
+
+ # Deferred is still without a result
+ self.assertNoResult(d)
+
+ # Push by enough to time it out
+ self.reactor.advance(120)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestTimedOutError)
+
+ def test_client_ip_range_blacklist(self):
+ """Ensure that Synapse does not try to connect to blacklisted IPs"""
+
+ # Add some DNS entries we'll blacklist
+ self.reactor.lookups["internal"] = "127.0.0.1"
+ self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
+ ip_blacklist = IPSet(["127.0.0.0/8", "fe80::/64"])
+
+ cl = SimpleHttpClient(self.hs, ip_blacklist=ip_blacklist)
+
+ # Try making a GET request to a blacklisted IPv4 address
+ # ------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(cl.get_json("http://internal:8008/foo/bar"))
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ self.failureResultOf(d, DNSLookupError)
+
+ # Try making a POST request to a blacklisted IPv6 address
+ # -------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(
+ cl.post_json_get_json("http://internalv6:8008/foo/bar", {})
+ )
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ # Check that it was due to a blacklisted DNS lookup
+ self.failureResultOf(d, DNSLookupError)
+
+ # Try making a GET request to a non-blacklisted IPv4 address
+ # ----------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(cl.get_json("http://testserv:8008/foo/bar"))
+
+ # Nothing has happened yet
+ self.assertNoResult(d)
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was able to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertNotEqual(len(clients), 0)
+
+ # Connection will still fail as this IP address does not resolve to anything
+ self.failureResultOf(d, RequestTimedOutError)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 561258a356..c0ee1cfbd6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
+from synapse.types import PersistedEventPosition
from tests.server import FakeTransport
@@ -58,7 +59,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
- return super(SlavedEventStoreTestCase, self).setUp()
+ return super().setUp()
def prepare(self, *args, **kwargs):
super().prepare(*args, **kwargs)
@@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
self.replicate()
+
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
+ )
self.check(
"get_rooms_for_user_with_stream_ordering",
(USER_ID_2,),
- {(ROOM_ID, j2.internal_metadata.stream_ordering)},
+ {(ROOM_ID, expected_pos)},
)
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# the membership change is only any use to us if the room is in the
# joined_rooms list.
if membership_changes:
- self.assertEqual(
- joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
)
+ self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
event_id = 0
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 8b4982ecb1..1d7edee5ba 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -45,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new event.
"""
mock_client = Mock(spec=["put_json"])
- mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
@@ -73,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new events.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -85,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -136,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new typing EDUs.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -148,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index faa7f381a9..92c9058887 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
request, channel = self.make_request(
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
new file mode 100644
index 0000000000..bf79086f78
--- /dev/null
+++ b/tests/rest/admin/test_event_reports.py
@@ -0,0 +1,382 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import report_event
+
+from tests import unittest
+
+
+class EventReportsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ report_event.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id1 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
+
+ self.room_id2 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok)
+
+ # Two rooms and two users. Every user sends and reports every room event
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.admin_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.admin_user_tok,
+ )
+
+ self.url = "/_synapse/admin/v1/event_reports"
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_default_success(self):
+ """
+ Testing list of reported events
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit(self):
+ """
+ Testing list of reported events with limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertEqual(channel.json_body["next_token"], 5)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_from(self):
+ """
+ Testing list of reported events with a defined starting point (from)
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of reported events with a defined starting point and limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(channel.json_body["next_token"], 15)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_filter_room(self):
+ """
+ Testing list of reported events with a filter of room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?room_id=%s" % self.room_id1,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_filter_user(self):
+ """
+ Testing list of reported events with a filter of user
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s" % self.other_user,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+
+ def test_filter_user_and_room(self):
+ """
+ Testing list of reported events with a filter of user and room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 5)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_valid_search_order(self):
+ """
+ Testing search order. Order by timestamps.
+ """
+
+ # fetch the most recent first, largest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ report = 1
+ while report < len(channel.json_body["event_reports"]):
+ self.assertGreaterEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+ report += 1
+
+ # fetch the oldest first, smallest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ report = 1
+ while report < len(channel.json_body["event_reports"]):
+ self.assertLessEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+ report += 1
+
+ def test_invalid_search_order(self):
+ """
+ Testing that a invalid search order returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual("Unknown direction: bar", channel.json_body["error"])
+
+ def test_limit_is_negative(self):
+ """
+ Testing that a negative list parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_from_is_negative(self):
+ """
+ Testing that a negative from parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 19)
+ self.assertEqual(channel.json_body["next_token"], 19)
+
+ # Check
+ # Set `from` to value of `next_token` for request remaining entries
+ # `next_token` does not appear
+ request, channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def _create_event_and_report(self, room_id, user_tok):
+ """Create and report events
+ """
+ resp = self.helper.send(room_id, tok=user_tok)
+ event_id = resp["event_id"]
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/report/%s" % (room_id, event_id),
+ json.dumps({"score": -100, "reason": "this makes me sad"}),
+ access_token=user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _check_fields(self, content):
+ """Checks that all attributes are present in a event report
+ """
+ for c in content:
+ self.assertIn("id", c)
+ self.assertIn("received_ts", c)
+ self.assertIn("room_id", c)
+ self.assertIn("event_id", c)
+ self.assertIn("user_id", c)
+ self.assertIn("reason", c)
+ self.assertIn("content", c)
+ self.assertIn("sender", c)
+ self.assertIn("room_alias", c)
+ self.assertIn("event_json", c)
+ self.assertIn("score", c["content"])
+ self.assertIn("reason", c["content"])
+ self.assertIn("auth_events", c["event_json"])
+ self.assertIn("type", c["event_json"])
+ self.assertIn("room_id", c["event_json"])
+ self.assertIn("sender", c["event_json"])
+ self.assertIn("content", c["event_json"])
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 408c568a27..6dfc709dc5 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1174,6 +1174,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body)
+ self.assertIn("topic", channel.json_body)
+ self.assertIn("avatar", channel.json_body)
self.assertIn("canonical_alias", channel.json_body)
self.assertIn("joined_members", channel.json_body)
self.assertIn("joined_local_members", channel.json_body)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 160c630235..98d0623734 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -22,8 +22,8 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
-from synapse.api.errors import HttpResponseException, ResourceLimitError
-from synapse.rest.client.v1 import login
+from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
@@ -337,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -591,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -631,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+ return_value=make_awaitable(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -874,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self._is_erased("@user:test", False)
+ d = self.store.mark_user_erased("@user:test")
+ self.assertIsNone(self.get_success(d))
+ self._is_erased("@user:test", True)
# Attempt to reactivate the user (without a password).
request, channel = self.make_request(
@@ -906,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self._is_erased("@user:test", False)
def test_set_user_as_admin(self):
"""
@@ -995,3 +1000,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
+
+ def _is_erased(self, user_id, expect):
+ """Assert that the user is erased or not
+ """
+ d = self.store.is_user_erased(user_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertFalse(self.get_success(d))
+
+
+class UserMembershipRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/joined_rooms" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list rooms of an user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_rooms(self):
+ """
+ Tests that a normal lookup for rooms is successfully
+ """
+ # Create rooms and join
+ other_user_tok = self.login("user", "pass")
+ number_rooms = 5
+ for n in range(number_rooms):
+ self.helper.create_room_as(self.other_user, tok=other_user_tok)
+
+ # Get rooms
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_rooms, channel.json_body["total"])
+ self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2668662c9e..5d987a30c7 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -7,8 +7,9 @@ from mock import Mock
import jwt
import synapse.rest.admin
+from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
-from synapse.rest.client.v2_alpha import devices
+from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from tests import unittest
@@ -748,3 +749,134 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
"JWT validation failed: Signature verification failed",
)
+
+
+AS_USER = "as_user_alice"
+
+
+class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ register.register_servlets,
+ ]
+
+ def register_as_user(self, username):
+ request, channel = self.make_request(
+ b"POST",
+ "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
+ {"username": username},
+ )
+ self.render(request)
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+
+ self.service = ApplicationService(
+ id="unique_identifier",
+ token="some_token",
+ hostname="example.com",
+ sender="@asbot:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {"regex": r"@as_user.*", "exclusive": False}
+ ],
+ ApplicationService.NS_ROOMS: [],
+ ApplicationService.NS_ALIASES: [],
+ },
+ )
+ self.another_service = ApplicationService(
+ id="another__identifier",
+ token="another_token",
+ hostname="example.com",
+ sender="@as2bot:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {"regex": r"@as2_user.*", "exclusive": False}
+ ],
+ ApplicationService.NS_ROOMS: [],
+ ApplicationService.NS_ALIASES: [],
+ },
+ )
+
+ self.hs.get_datastore().services_cache.append(self.service)
+ self.hs.get_datastore().services_cache.append(self.another_service)
+ return self.hs
+
+ def test_login_appservice_user(self):
+ """Test that an appservice user can use /login
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_login_appservice_user_bot(self):
+ """Test that the appservice bot can use /login
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": self.service.sender},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_login_appservice_wrong_user(self):
+ """Test that non-as users cannot login with the as token
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_login_appservice_wrong_as(self):
+ """Test that as users cannot login with wrong as token
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.another_service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_login_appservice_no_token(self):
+ """Test that users must provide a token when using the appservice
+ login method
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
new file mode 100644
index 0000000000..081052f6a6
--- /dev/null
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, push_rule, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PushRuleAttributesTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ push_rule.register_servlets,
+ ]
+ hijack_auth = False
+
+ def test_enabled_on_creation(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is enabled upon creation, even though a rule with that
+ ruleId existed previously and was disabled.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_on_recreation(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is enabled upon creation, even if a rule with that
+ ruleId existed previously and was disabled.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # disable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": False},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule disabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], False)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_disable(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is disabled and enabled when we ask for it.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # disable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": False},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule disabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], False)
+
+ # re-enable the rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check rule enabled
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["enabled"], True)
+
+ def test_enabled_404_when_get_non_existent(self):
+ """
+ Tests that `enabled` gives 404 when the rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET enabled for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check 404 for deleted rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_get_non_existent_server_rule(self):
+ """
+ Tests that `enabled` gives 404 when the server-default rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_put_non_existent_rule(self):
+ """
+ Tests that `enabled` gives 404 when we put to a rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_enabled_404_when_put_non_existent_server_rule(self):
+ """
+ Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/.m.muahahah/enabled",
+ {"enabled": True},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_get(self):
+ """
+ Tests that `actions` gives you what you expect on a fresh rule.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET actions for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
+ )
+
+ def test_actions_put(self):
+ """
+ Tests that PUT on actions updates the value you'd get from GET.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # change the rule actions
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # GET actions for that new rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["actions"], ["dont_notify"])
+
+ def test_actions_404_when_get_non_existent(self):
+ """
+ Tests that `actions` gives 404 when the rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ body = {
+ "conditions": [
+ {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+ ],
+ "actions": ["notify", {"set_tweak": "highlight"}],
+ }
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ # PUT a new rule
+ request, channel = self.make_request(
+ "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # DELETE the rule
+ request, channel = self.make_request(
+ "DELETE", "/pushrules/global/override/best.friend", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # check 404 for deleted rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_get_non_existent_server_rule(self):
+ """
+ Tests that `actions` gives 404 when the server-default rule doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_put_non_existent_rule(self):
+ """
+ Tests that `actions` gives 404 when putting to a rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/best.friend/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_actions_404_when_put_non_existent_server_rule(self):
+ """
+ Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist.
+ """
+ self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ # enable & check 404 for never-heard-of rule
+ request, channel = self.make_request(
+ "PUT",
+ "/pushrules/global/override/.m.muahahah/actions",
+ {"actions": ["dont_notify"]},
+ access_token=token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0a567b032f..0d809d25d5 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -905,6 +905,7 @@ class RoomMessageListTestCase(RoomBase):
first_token = self.get_success(
store.get_topological_token_for_event(first_event_id)
)
+ first_token_str = self.get_success(first_token.to_string(store))
# Send a second message in the room, which won't be removed, and which we'll
# use as the marker to purge events before.
@@ -912,6 +913,7 @@ class RoomMessageListTestCase(RoomBase):
second_token = self.get_success(
store.get_topological_token_for_event(second_event_id)
)
+ second_token_str = self.get_success(second_token.to_string(store))
# Send a third event in the room to ensure we don't fall under any edge case
# due to our marker being the latest forward extremity in the room.
@@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
pagination_handler._purge_history(
purge_id=purge_id,
room_id=self.room_id,
- token=second_token,
+ token=second_token_str,
delete_local_events=True,
)
)
@@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ first_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 152a5182fa..ae2cd67f35 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -14,11 +14,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
import os
import re
from email.parser import Parser
+from typing import Optional
+from urllib.parse import urlencode
import pkg_resources
@@ -27,8 +28,10 @@ from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
+from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest
+from tests.unittest import override_config
class PasswordResetTestCase(unittest.HomeserverTestCase):
@@ -69,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
def test_basic_password_reset(self):
"""Test basic password reset flow
@@ -250,8 +254,32 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Remove the host
path = link.replace("https://example.com", "")
+ # Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False)
- self.render(request)
+ request.render(self.submit_token_resource)
+ self.pump()
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Now POST to the same endpoint, mimicking the same behaviour as clicking the
+ # password reset confirm button
+
+ # Send arguments as url-encoded form data, matching the template's behaviour
+ form_args = []
+ for key, value_list in request.args.items():
+ for value in value_list:
+ arg = (key, value)
+ form_args.append(arg)
+
+ # Confirm the password reset
+ request, channel = self.make_request(
+ "POST",
+ path,
+ content=urlencode(form_args).encode("utf8"),
+ shorthand=False,
+ content_is_form=True,
+ )
+ request.render(self.submit_token_resource)
+ self.pump()
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -668,16 +696,110 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
- def _request_token(self, email, client_secret):
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link(self):
+ """Tests a valid next_link parameter value with no whitelist (good case)"""
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/good/site",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_exotic_protocol(self):
+ """Tests using a esoteric protocol as a next_link parameter value.
+ Someone may be hosting a client on IPFS etc.
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_file_uri(self):
+ """Tests next_link parameters cannot be file URI"""
+ # Attempt to use a next_link value that points to the local disk
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="file:///host/path",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
+ def test_next_link_domain_whitelist(self):
+ """Tests next_link parameters must fit the whitelist if provided"""
+
+ # Ensure not providing a next_link parameter still works
+ self._request_token(
+ "something@example.com", "some_secret", next_link=None, expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/some/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.org/some/also/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://bad.example.org/some/bad/page",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": []})
+ def test_empty_next_link_domain_whitelist(self):
+ """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
+ disallowed
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/page",
+ expect_code=400,
+ )
+
+ def _request_token(
+ self,
+ email: str,
+ client_secret: str,
+ next_link: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> str:
+ """Request a validation token to add an email address to a user's account
+
+ Args:
+ email: The email address to validate
+ client_secret: A secret string
+ next_link: A link to redirect the user to after validation
+ expect_code: Expected return code of the call
+
+ Returns:
+ The ID of the new threepid validation session
+ """
+ body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+ if next_link:
+ body["next_link"] = next_link
+
request, channel = self.make_request(
- "POST",
- b"account/3pid/email/requestToken",
- {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ "POST", b"account/3pid/email/requestToken", body,
)
self.render(request)
- self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(expect_code, channel.code, channel.result)
- return channel.json_body["sid"]
+ return channel.json_body.get("sid")
def _request_token_invalid_email(
self, email, expected_errcode, expected_error, client_secret="foobar",
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index f4f3e56777..5f897d49cf 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -120,12 +120,13 @@ class _TestImage:
extension = attr.ib(type=bytes)
expected_cropped = attr.ib(type=Optional[bytes])
expected_scaled = attr.ib(type=Optional[bytes])
+ expected_found = attr.ib(default=True, type=bool)
@parameterized_class(
("test_image",),
[
- # smol png
+ # smoll png
(
_TestImage(
unhexlify(
@@ -161,6 +162,8 @@ class _TestImage:
None,
),
),
+ # an empty file
+ (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
@@ -303,12 +306,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
- self._test_thumbnail("crop", self.test_image.expected_cropped)
+ self._test_thumbnail(
+ "crop", self.test_image.expected_cropped, self.test_image.expected_found
+ )
def test_thumbnail_scale(self):
- self._test_thumbnail("scale", self.test_image.expected_scaled)
+ self._test_thumbnail(
+ "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ )
- def _test_thumbnail(self, method, expected_body):
+ def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
request, channel = self.make_request(
"GET", self.media_id + params, shorthand=False
@@ -325,11 +332,23 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
self.pump()
- self.assertEqual(channel.code, 200)
- if expected_body is not None:
+ if expected_found:
+ self.assertEqual(channel.code, 200)
+ if expected_body is not None:
+ self.assertEqual(
+ channel.result["body"], expected_body, channel.result["body"]
+ )
+ else:
+ # ensure that the result is at least some valid image
+ Image.open(BytesIO(channel.result["body"]))
+ else:
+ # A 404 with a JSON body.
+ self.assertEqual(channel.code, 404)
self.assertEqual(
- channel.result["body"], expected_body, channel.result["body"]
+ channel.json_body,
+ {
+ "errcode": "M_NOT_FOUND",
+ "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
+ % method,
+ },
)
- else:
- # ensure that the result is at least some valid image
- Image.open(BytesIO(channel.result["body"]))
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index b090bb974c..dcd65c2a50 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -21,7 +21,7 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
def setUp(self):
- super(WellKnownTests, self).setUp()
+ super().setUp()
# replace the JsonResource with a WellKnownResource
self.resource = WellKnownResource(self.hs)
diff --git a/tests/server.py b/tests/server.py
index 48e45c6c8b..b404ad4e2a 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,6 +1,6 @@
import json
import logging
-from io import BytesIO
+from io import SEEK_END, BytesIO
import attr
from zope.interface import implementer
@@ -135,6 +135,7 @@ def make_request(
request=SynapseRequest,
shorthand=True,
federation_auth_origin=None,
+ content_is_form=False,
):
"""
Make a web request using the given method and path, feed it the
@@ -150,6 +151,8 @@ def make_request(
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
+ content_is_form: Whether the content is URL encoded form data. Adds the
+ 'Content-Type': 'application/x-www-form-urlencoded' header.
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
@@ -181,6 +184,8 @@ def make_request(
req = request(channel)
req.process = lambda: b""
req.content = BytesIO(content)
+ # Twisted expects to be at the end of the content when parsing the request.
+ req.content.seek(SEEK_END)
req.postpath = list(map(unquote, path[1:].split(b"/")))
if access_token:
@@ -195,7 +200,13 @@ def make_request(
)
if content:
- req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+ if content_is_form:
+ req.requestHeaders.addRawHeader(
+ b"Content-Type", b"application/x-www-form-urlencoded"
+ )
+ else:
+ # Assume the body is JSON
+ req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
req.requestReceived(method, path, b"1.1")
@@ -249,7 +260,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return succeed(lookups[name])
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
- super(ThreadedMemoryReactorClock, self).__init__()
+ super().__init__()
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
p = udp.Port(port, protocol, interface, maxPacketSize, self)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 973338ea71..6382b19dc3 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
self._rlsn._store.user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(1000)
+ return_value=make_awaitable(1000)
)
self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock())
@@ -80,9 +80,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
- self._rlsn._store.get_tags_for_room = Mock(
- side_effect=lambda user_id, room_id: make_awaitable({})
- )
+ self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
@@ -158,7 +156,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(None)
+ return_value=make_awaitable(None)
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -261,12 +259,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
- self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(1000)
- )
+ self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.user_last_seen_monthly_active = Mock(
- side_effect=lambda user_id: make_awaitable(1000)
+ return_value=make_awaitable(1000)
)
# Call the function multiple times to ensure we only send the notice once
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index cb808d4de4..46f94914ff 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -413,7 +413,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(TestTransactionStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 40ba652248..eac7e4dcd2 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -56,6 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
engine = create_engine(sqlite_config)
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
+ fake_engine.in_transaction.return_value = False
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 370c247e16..755c70db31 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -154,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
user_id = "@user:server"
self.store.get_monthly_active_count = Mock(
- side_effect=lambda: make_awaitable(lots_of_users)
+ return_value=make_awaitable(lots_of_users)
)
self.get_success(
self.store.insert_client_ip(
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 34ae8c9da7..ecb00f4e02 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -23,7 +23,7 @@ import tests.utils
class DeviceStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(DeviceStoreTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 949846fe33..3957471f3f 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -52,14 +52,14 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
self.reactor.advance(60 * 60 * 1000)
self.pump(1)
- items = set(
+ items = list(
filter(
lambda x: b"synapse_forward_extremities_" in x,
- generate_latest(REGISTRY).split(b"\n"),
+ generate_latest(REGISTRY, emit_help=False).split(b"\n"),
)
)
- expected = {
+ expected = [
b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
@@ -72,9 +72,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
- b"synapse_forward_extremities_count 3.0",
- b"synapse_forward_extremities_sum 10.0",
- }
-
+ # per https://docs.google.com/document/d/1KwV0mAXwwbvvifBvDKH_LU1YjyXE_wxCkHNoCGq1GX0/edit#heading=h.wghdjzzh72j9,
+ # "inf" is valid: "this includes variants such as inf"
+ b'synapse_forward_extremities_bucket{le="inf"} 3.0',
+ b"# TYPE synapse_forward_extremities_gcount gauge",
+ b"synapse_forward_extremities_gcount 3.0",
+ b"# TYPE synapse_forward_extremities_gsum gauge",
+ b"synapse_forward_extremities_gsum 10.0",
+ ]
self.assertEqual(items, expected)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index f0a8e32f1e..392b08832b 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
from synapse.storage.database import DatabasePool
+from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.unittest import HomeserverTestCase
@@ -43,19 +42,23 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
)
- return self.get_success(self.db_pool.runWithConnection(_create))
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
"""Insert N rows as the given instance, inserting with stream IDs pulled
@@ -68,6 +71,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
(instance_name,),
)
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
@@ -81,6 +91,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, stream_id, stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
@@ -111,7 +128,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
- with await id_gen.get_next() as stream_id:
+ async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
@@ -122,6 +139,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+ def test_out_of_order_finish(self):
+ """Test that IDs persisted out of order are correctly handled
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+ ctx3 = self.get_success(id_gen.get_next())
+ ctx4 = self.get_success(id_gen.get_next())
+
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+ s3 = self.get_success(ctx3.__aenter__())
+ s4 = self.get_success(ctx4.__aenter__())
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+ self.assertEqual(s3, 10)
+ self.assertEqual(s4, 11)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ self.get_success(ctx2.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ self.get_success(ctx1.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+ self.get_success(ctx4.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+ self.get_success(ctx3.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 11})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
+
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
correctly.
@@ -129,8 +196,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_rows("first", 3)
self._insert_rows("second", 4)
- first_id_gen = self._create_id_generator("first")
- second_id_gen = self._create_id_generator("second")
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
@@ -140,7 +207,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
- with await first_id_gen.get_next() as stream_id:
+ async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(
@@ -158,7 +225,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# stream ID
async def _get_next_async():
- with await second_id_gen.get_next() as stream_id:
+ async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
self.assertEqual(
@@ -212,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -250,14 +317,18 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
- with self.get_success(id_gen.get_next()) as stream_id:
- self.assertEqual(stream_id, 6)
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
@@ -265,6 +336,115 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
+ def test_restart_during_out_of_order_persistence(self):
+ """Test that restarting a process while another process is writing out
+ of order updates are handled correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # Persist two rows at once
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # We finish persisting the second row before restart
+ self.get_success(ctx2.__aexit__(None, None, None))
+
+ # We simulate a restart of another worker by just creating a new ID gen.
+ id_gen_worker = self._create_id_generator("worker")
+
+ # Restarted worker should not see the second persisted row
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
+ self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
+
+ # Now if we persist the first row then both instances should jump ahead
+ # correctly.
+ self.get_success(ctx1.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ id_gen_worker.advance("master", 9)
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
+
+ def test_writer_config_change(self):
+ """Test that changing the writer config correctly works.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ # Initial config has two writers
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
+
+ # New config removes one of the configs. Note that if the writer is
+ # removed from config we assume that it has been shut down and has
+ # finished persisting, hence why the persisted upto position is 5.
+ id_gen_2 = self._create_id_generator("second", writers=["second"])
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), 5)
+ self.assertEqual(id_gen_2.get_current_token_for_writer("second"), 5)
+
+ # This config points to a single, previously unused writer.
+ id_gen_3 = self._create_id_generator("third", writers=["third"])
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 5)
+
+ # For new writers we assume their initial position to be the current
+ # persisted up to position. This stops Synapse from doing a full table
+ # scan when a new writer comes along.
+ self.assertEqual(id_gen_3.get_current_token_for_writer("third"), 5)
+
+ id_gen_4 = self._create_id_generator("fourth", writers=["third"])
+ self.assertEqual(id_gen_4.get_current_token_for_writer("third"), 5)
+
+ # Check that we get a sane next stream ID with this new config.
+
+ async def _get_next_async():
+ async with id_gen_3.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+
+ self.get_success(_get_next_async())
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+
+ # If we add back the old "first" then we shouldn't see the persisted up
+ # to position revert back to 3.
+ id_gen_5 = self._create_id_generator("five", writers=["first", "third"])
+ self.assertEqual(id_gen_5.get_persisted_upto_position(), 6)
+ self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
+ self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
+
+ def test_sequence_consistency(self):
+ """Test that we error out if the table and sequence diverges.
+ """
+
+ # Prefill with some rows
+ self._insert_row_with_id("master", 3)
+
+ # Now we add a row *without* updating the stream ID
+ def _insert(txn):
+ txn.execute("INSERT INTO foobar VALUES (26, 'master')")
+
+ self.get_success(self.db_pool.runInteraction("_insert", _insert))
+
+ # Creating the ID gen should error
+ with self.assertRaises(IncorrectDatabaseSetup):
+ self._create_id_generator("first")
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
@@ -291,16 +471,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
positive=False,
)
@@ -314,6 +498,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
txn.execute(
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, -stream_id, -stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
@@ -323,16 +514,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
id_gen = self._create_id_generator()
- with self.get_success(id_gen.get_next()) as stream_id:
- self._insert_row("master", stream_id)
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": -1})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
- with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
- for stream_id in stream_ids:
- self._insert_row("master", stream_id)
+ async def _get_next_async2():
+ async with id_gen.get_next_mult(3) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async2())
self.assertEqual(id_gen.get_positions(), {"master": -4})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
@@ -349,21 +546,27 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests that having multiple instances that get advanced over
federation works corretly.
"""
- id_gen_1 = self._create_id_generator("first")
- id_gen_2 = self._create_id_generator("second")
+ id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
+ id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
- with self.get_success(id_gen_1.get_next()) as stream_id:
- self._insert_row("first", stream_id)
- id_gen_2.advance("first", stream_id)
+ async def _get_next_async():
+ async with id_gen_1.get_next() as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
- with self.get_success(id_gen_2.get_next()) as stream_id:
- self._insert_row("second", stream_id)
- id_gen_1.advance("second", stream_id)
+ async def _get_next_async2():
+ async with id_gen_2.get_next() as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.get_success(_get_next_async2())
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9870c74883..8d97b6d4cd 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -137,6 +137,21 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 1)
+ def test_appservice_user_not_counted_in_mau(self):
+ self.get_success(
+ self.store.register_user(
+ user_id="@appservice_user:server", appservice_id="wibble"
+ )
+ )
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
+ d = self.store.upsert_monthly_active_user("@appservice_user:server")
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
@@ -231,9 +246,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@@ -241,9 +254,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
@@ -256,9 +267,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
@@ -344,9 +353,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.store.upsert_monthly_active_user = Mock(
- side_effect=lambda user_id: make_awaitable(None)
- )
+ self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
@@ -391,7 +398,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, 4)
+ self.assertEqual(count, 1)
d = self.store.get_monthly_active_count_by_service()
result = self.get_success(d)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 918387733b..cc1f3c53c5 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
- event = self.get_success(
+ token = self.get_success(
store.get_topological_token_for_event(last["event_id"])
)
+ token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
- self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
+ self.get_success(
+ storage.purge_events.purge_history(self.room_id, token_str, True)
+ )
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
@@ -74,12 +77,10 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
- event = self.get_success(
+ token = self.get_success(
storage.get_topological_token_for_event(last["event_id"])
)
- event = "t{}-{}".format(
- *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
- )
+ event = "t{}-{}".format(token.topological + 1, token.stream + 1)
# Purge everything before this topological token
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
diff --git a/tests/test_state.py b/tests/test_state.py
index 2d58467932..80b0ccbc40 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -125,7 +125,7 @@ class StateGroupStore:
class DictObj(dict):
def __init__(self, **kwargs):
- super(DictObj, self).__init__(kwargs)
+ super().__init__(kwargs)
self.__dict__ = self
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 508aeba078..a298cc0fd3 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,6 +17,7 @@
"""
Utilities for running the unit tests
"""
+from asyncio import Future
from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")
-async def make_awaitable(result: Any):
- """Create an awaitable that just returns a result."""
- return result
+def make_awaitable(result: Any) -> Awaitable[Any]:
+ """
+ Makes an awaitable, suitable for mocking an `async` function.
+ This uses Futures as they can be awaited multiple times so can be returned
+ to multiple callers.
+ """
+ future = Future() # type: ignore
+ future.set_result(result)
+ return future
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index fb1ca90336..e93aa84405 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -71,7 +71,10 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- await hs.get_storage().persistence.persist_event(event, context)
+ persistence = hs.get_storage().persistence
+ assert persistence is not None
+
+ await persistence.persist_event(event, context)
return event
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 2d96b0fa8d..fdfb840b62 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -29,8 +29,7 @@ class ToTwistedHandler(logging.Handler):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
- twisted.logger.LogLevel.levelWithName(log_level),
- log_entry.replace("{", r"(").replace("}", r")"),
+ twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
diff --git a/tests/unittest.py b/tests/unittest.py
index 3cb55a7e96..e654c0442d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import gc
import hashlib
import hmac
@@ -23,11 +22,12 @@ import logging
import time
from typing import Optional, Tuple, Type, TypeVar, Union
-from mock import Mock
+from mock import Mock, patch
from canonicaljson import json
from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@@ -92,7 +92,7 @@ class TestCase(unittest.TestCase):
root logger's logging level while that test (case|method) runs."""
def __init__(self, methodName, *args, **kwargs):
- super(TestCase, self).__init__(methodName, *args, **kwargs)
+ super().__init__(methodName, *args, **kwargs)
method = getattr(self, methodName)
@@ -169,6 +169,19 @@ def INFO(target):
return target
+def logcontext_clean(target):
+ """A decorator which marks the TestCase or method as 'logcontext_clean'
+
+ ... ie, any logcontext errors should cause a test failure
+ """
+
+ def logcontext_error(msg):
+ raise AssertionError("logcontext error: %s" % (msg))
+
+ patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
+ return patcher(target)
+
+
class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -353,6 +366,7 @@ class HomeserverTestCase(TestCase):
request: Type[T] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: str = None,
+ content_is_form: bool = False,
) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -368,6 +382,8 @@ class HomeserverTestCase(TestCase):
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
+ content_is_form: Whether the content is URL encoded form data. Adds the
+ 'Content-Type': 'application/x-www-form-urlencoded' header.
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
@@ -384,6 +400,7 @@ class HomeserverTestCase(TestCase):
request,
shorthand,
federation_auth_origin,
+ content_is_form,
)
def render(self, request):
@@ -459,6 +476,35 @@ class HomeserverTestCase(TestCase):
self.pump()
return self.failureResultOf(d, exc)
+ def get_success_or_raise(self, d, by=0.0):
+ """Drive deferred to completion and return result or raise exception
+ on failure.
+ """
+
+ if inspect.isawaitable(d):
+ deferred = ensureDeferred(d)
+ if not isinstance(deferred, Deferred):
+ return d
+
+ results = [] # type: list
+ deferred.addBoth(results.append)
+
+ self.pump(by=by)
+
+ if not results:
+ self.fail(
+ "Success result expected on {!r}, found no result instead".format(
+ deferred
+ )
+ )
+
+ result = results[0]
+
+ if isinstance(result, Failure):
+ result.raiseException()
+
+ return result
+
def register_user(self, username, password, admin=False):
"""
Register a user. Requires the Admin API be registered.
|