From 693516e7566fdde677f5d836d45c44fbd8722ba9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 Dec 2020 15:21:00 +0000 Subject: Add `create_resource_dict` method to HomeserverTestCase Rather than using a single JsonResource, construct a resource tree, as we do in the prod code, and allow testcases to add extra resources by overriding `create_resource_dict`. --- tests/unittest.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index a9d59e31f7..425b39b1d1 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Optional, Tuple, Type, TypeVar, Union, overload +from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload from mock import Mock, patch @@ -46,6 +46,7 @@ from synapse.logging.context import ( ) from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util.httpresourcetree import create_resource_tree from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver @@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase): """ Create a the root resource for the test server. - The default implementation creates a JsonResource and calls each function in - `servlets` to register servletes against it + The default calls `self.create_resource_dict` and builds the resultant dict + into a tree. """ - resource = JsonResource(self.hs) + root_resource = Resource() + create_resource_tree(self.create_resource_dict(), root_resource) + return root_resource - for servlet in self.servlets: - servlet(self.hs, resource) + def create_resource_dict(self) -> Dict[str, Resource]: + """Create a resource tree for the test server + + A resource tree is a mapping from path to twisted.web.resource. - return resource + The default implementation creates a JsonResource and calls each function in + `servlets` to register servlets against it. + """ + servlet_resource = JsonResource(self.hs) + for servlet in self.servlets: + servlet(self.hs, servlet_resource) + return { + "/_matrix/client": servlet_resource, + "/_synapse/admin": servlet_resource, + } def default_config(self): """ -- cgit 1.5.1 From c5b6abd53d1549c7a7cc30c644dd8921bfc10ea2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Dec 2020 15:22:37 +0000 Subject: Correctly handle unpersisted events when calculating auth chain difference. (#8827) We do state res with unpersisted events when calculating the new current state of the room, so that should be the only thing impacted. I don't think this is tooooo big of a deal as: 1. the next time a state event happens in the room the current state should correct itself; 2. in the common case all the unpersisted events' auth events will be pulled in by other state, so will still return the correct result (or one which is sufficiently close to not affect the result); and 3. we mostly use the state at an event to do important operations, which isn't affected by this. 
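To make the calculation concrete before the diff, here is a minimal, self-contained sketch of the union-minus-intersection step described above. It is an illustration rather than the patch itself: the StubEvent class and the function name are hypothetical, and only the auth_event_ids() accessor mirrors Synapse's EventBase.

from typing import Dict, List, Set


class StubEvent:
    """Hypothetical stand-in for an unpersisted event; only auth_event_ids()
    mirrors Synapse's EventBase."""

    def __init__(self, event_id: str, auth_ids: List[str]):
        self.event_id = event_id
        self._auth_ids = auth_ids

    def auth_event_ids(self) -> List[str]:
        return self._auth_ids


def unpersisted_auth_chain_difference(
    state_sets: List[Set[str]], event_map: Dict[str, StubEvent]
) -> Set[str]:
    # For each state set, walk auth events to collect the unpersisted
    # event IDs (those in `event_map`) reachable from that set.
    per_set_chains = []  # type: List[Set[str]]
    for state_set in state_sets:
        reachable = set()  # type: Set[str]
        to_search = [e for e in state_set if e in event_map]
        while to_search:
            event_id = to_search.pop()
            if event_id in reachable:
                continue
            reachable.add(event_id)
            to_search.extend(
                a for a in event_map[event_id].auth_event_ids() if a in event_map
            )
        per_set_chains.append(reachable)

    # The unpersisted part of the difference is the union of the per-set
    # chains minus their intersection.
    union = set().union(*per_set_chains)
    intersection = set.intersection(*per_set_chains) if per_set_chains else set()
    return union - intersection


# Mirrors the `test_simple` case added below: C is unpersisted, B and A are not.
assert unpersisted_auth_chain_difference(
    [{"a", "b"}, {"c"}], {"c": StubEvent("c", ["b"])}
) == {"c"}

The real implementation additionally replaces each unpersisted event in the state sets with the persisted events in its auth chain before calling the store, as the diff below shows.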
--- changelog.d/8827.bugfix | 1 + synapse/state/v2.py | 87 ++++++++++++++++++++-- tests/state/test_v2.py | 128 ++++++++++++++++++++++++++++++++- tests/storage/test_event_federation.py | 5 ++ 4 files changed, 216 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8827.bugfix (limited to 'tests') diff --git a/changelog.d/8827.bugfix b/changelog.d/8827.bugfix new file mode 100644 index 0000000000..18195680d3 --- /dev/null +++ b/changelog.d/8827.bugfix @@ -0,0 +1 @@ +Fix bug where we might not correctly calculate the current state for rooms with multiple extremities. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index f57df0d728..ffc504ce77 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import Collection, MutableStateMap, StateMap from synapse.util import Clock logger = logging.getLogger(__name__) @@ -252,9 +252,88 @@ async def _get_auth_chain_difference( Set of event IDs """ - difference = await state_res_store.get_auth_chain_difference( - [set(state_set.values()) for state_set in state_sets] - ) + # The `StateResolutionStore.get_auth_chain_difference` function assumes that + # all events passed to it (and their auth chains) have been persisted + # previously. This is not the case for any events in the `event_map`, and so + # we need to manually handle those events. + # + # We do this by: + # 1. calculating the auth chain difference for the state sets based on the + # events in `event_map` alone + # 2. replacing any events in the state_sets that are also in `event_map` + # with their auth events (recursively), and then calling + # `store.get_auth_chain_difference` as normal + # 3. adding the results of 1 and 2 together. + + # Map from event ID in `event_map` to their auth event IDs, and their auth + # event IDs if they appear in the `event_map`. This is the intersection of + # the event's auth chain with the events in the `event_map` *plus* their + # auth event IDs. + events_to_auth_chain = {} # type: Dict[str, Set[str]] + for event in event_map.values(): + chain = {event.event_id} + events_to_auth_chain[event.event_id] = chain + + to_search = [event] + while to_search: + for auth_id in to_search.pop().auth_event_ids(): + chain.add(auth_id) + auth_event = event_map.get(auth_id) + if auth_event: + to_search.append(auth_event) + + # We now a) calculate the auth chain difference for the unpersisted events + # and b) work out the state sets to pass to the store. + # + # Note: If the `event_map` is empty (which is the common case), we can do a + # much simpler calculation. + if event_map: + # The list of state sets to pass to the store, where each state set is a set + # of the event ids making up the state. This is similar to `state_sets`, + # except that (a) we only have event ids, not the complete + # ((type, state_key)->event_id) mappings; and (b) we have stripped out + # unpersisted events and replaced them with the persisted events in + # their auth chain. + state_sets_ids = [] # type: List[Set[str]] + + # For each state set, the unpersisted event IDs reachable (by their auth + # chain) from the events in that set. 
+ unpersisted_set_ids = [] # type: List[Set[str]] + + for state_set in state_sets: + set_ids = set() # type: Set[str] + state_sets_ids.append(set_ids) + + unpersisted_ids = set() # type: Set[str] + unpersisted_set_ids.append(unpersisted_ids) + + for event_id in state_set.values(): + event_chain = events_to_auth_chain.get(event_id) + if event_chain is not None: + # We have an event in `event_map`. We add all the auth + # events that it references (that aren't also in `event_map`). + set_ids.update(e for e in event_chain if e not in event_map) + + # We also add the full chain of unpersisted event IDs + # referenced by this state set, so that we can work out the + # auth chain difference of the unpersisted events. + unpersisted_ids.update(e for e in event_chain if e in event_map) + else: + set_ids.add(event_id) + + # The auth chain difference of the unpersisted events of the state sets + # is calculated by taking the difference between the union and + # intersections. + union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) + intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) + + difference_from_event_map = union - intersection # type: Collection[str] + else: + difference_from_event_map = () + state_sets_ids = [set(state_set.values()) for state_set in state_sets] + + difference = await state_res_store.get_auth_chain_difference(state_sets_ids) + difference.update(difference_from_event_map) return difference diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index ad9bbef9d2..f5c6db900d 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event from synapse.events import make_event_from_dict -from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store +from synapse.state.v2 import ( + _get_auth_chain_difference, + lexicographical_topological_sort, + resolve_events_with_store, +) from synapse.types import EventID from tests import unittest @@ -587,6 +591,128 @@ class SimpleParamStateTestCase(unittest.TestCase): self.assert_dict(self.expected_combined_state, state) +class AuthChainDifferenceTestCase(unittest.TestCase): + """We test that `_get_auth_chain_difference` correctly handles unpersisted + events. 
+ """ + + def test_simple(self): + # Test getting the auth difference for a simple chain with a single + # unpersisted event: + # + # Unpersisted | Persisted + # | + # C -|-> B -> A + + a = FakeEvent( + id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([], []) + + b = FakeEvent( + id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([a.event_id], []) + + c = FakeEvent( + id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([b.event_id], []) + + persisted_events = {a.event_id: a, b.event_id: b} + unpersited_events = {c.event_id: c} + + state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}] + + store = TestStateResolutionStore(persisted_events) + + diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + difference = self.successResultOf(defer.ensureDeferred(diff_d)) + + self.assertEqual(difference, {c.event_id}) + + def test_multiple_unpersisted_chain(self): + # Test getting the auth difference for a simple chain with multiple + # unpersisted events: + # + # Unpersisted | Persisted + # | + # D -> C -|-> B -> A + + a = FakeEvent( + id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([], []) + + b = FakeEvent( + id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([a.event_id], []) + + c = FakeEvent( + id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([b.event_id], []) + + d = FakeEvent( + id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([c.event_id], []) + + persisted_events = {a.event_id: a, b.event_id: b} + unpersited_events = {c.event_id: c, d.event_id: d} + + state_sets = [ + {"a": a.event_id, "b": b.event_id}, + {"c": c.event_id, "d": d.event_id}, + ] + + store = TestStateResolutionStore(persisted_events) + + diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + difference = self.successResultOf(defer.ensureDeferred(diff_d)) + + self.assertEqual(difference, {d.event_id, c.event_id}) + + def test_unpersisted_events_different_sets(self): + # Test getting the auth difference for with multiple unpersisted events + # in different branches: + # + # Unpersisted | Persisted + # | + # D --> C -|-> B -> A + # E ----^ -|---^ + # | + + a = FakeEvent( + id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([], []) + + b = FakeEvent( + id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([a.event_id], []) + + c = FakeEvent( + id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([b.event_id], []) + + d = FakeEvent( + id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([c.event_id], []) + + e = FakeEvent( + id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([c.event_id, b.event_id], []) + + persisted_events = {a.event_id: a, b.event_id: b} + unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e} + + state_sets = [ + {"a": a.event_id, "b": b.event_id, "e": e.event_id}, + {"c": c.event_id, "d": d.event_id}, + ] + + store = TestStateResolutionStore(persisted_events) + + diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + difference = self.successResultOf(defer.ensureDeferred(diff_d)) + + self.assertEqual(difference, {d.event_id, e.event_id}) + + def pairwise(iterable): "s -> (s0,s1), (s1,s2), (s2, s3), ..." 
a, b = itertools.tee(iterable) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index d4c3b867e3..71c21d8c75 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -216,6 +216,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) self.assertSetEqual(difference, {"a", "b", "c"}) + difference = self.get_success( + self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + difference = self.get_success( self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) ) -- cgit 1.5.1 From 30fba6210834a4ecd91badf0c8f3eb278b72e746 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Dec 2020 11:09:24 -0500 Subject: Apply an IP range blacklist to push and key revocation requests. (#8821) Replaces the `federation_ip_range_blacklist` configuration setting with an `ip_range_blacklist` setting with wider scope. It now applies to: * Federation * Identity servers * Push notifications * Checking key validity for third-party invite events The old `federation_ip_range_blacklist` setting is still honored if present, but with reduced scope (it only applies to federation and identity servers). --- changelog.d/8821.bugfix | 1 + docs/sample_config.yaml | 14 ++++--- synapse/app/generic_worker.py | 1 - synapse/config/federation.py | 40 ++++++++++++------- synapse/crypto/keyring.py | 4 +- synapse/federation/federation_server.py | 1 - synapse/federation/transport/client.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 6 +-- synapse/http/client.py | 46 +++++++++++++++------- synapse/http/federation/matrix_federation_agent.py | 16 ++++++-- synapse/http/matrixfederationclient.py | 26 ++++-------- synapse/push/httppusher.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/server.py | 36 +++++++++++++---- tests/api/test_filtering.py | 4 +- tests/app/test_frontend_proxy.py | 2 +- tests/app/test_openid_listener.py | 4 +- tests/crypto/test_keyring.py | 6 ++- tests/handlers/test_device.py | 4 +- tests/handlers/test_directory.py | 2 +- tests/handlers/test_federation.py | 2 +- tests/handlers/test_presence.py | 2 +- tests/handlers/test_profile.py | 2 +- tests/handlers/test_typing.py | 6 +-- .../federation/test_matrix_federation_agent.py | 3 ++ tests/push/test_http.py | 4 +- tests/replication/_base.py | 4 +- tests/replication/test_federation_sender_shard.py | 10 ++--- tests/replication/test_pusher_shard.py | 6 +-- tests/rest/admin/test_admin.py | 2 +- tests/rest/client/v1/test_presence.py | 2 +- tests/rest/client/v1/test_profile.py | 2 +- tests/rest/client/v1/test_rooms.py | 2 +- tests/rest/client/v1/test_typing.py | 2 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +- tests/rest/media/v1/test_media_storage.py | 2 +- tests/storage/test_e2e_room_keys.py | 2 +- tests/storage/test_purge.py | 2 +- tests/storage/test_redaction.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/test_federation.py | 2 +- tests/test_server.py | 5 ++- 43 files changed, 176 insertions(+), 115 deletions(-) create mode 100644 changelog.d/8821.bugfix (limited to 'tests') diff --git a/changelog.d/8821.bugfix b/changelog.d/8821.bugfix new file mode 100644 index 0000000000..8ddfbf31ce --- /dev/null +++ b/changelog.d/8821.bugfix @@ -0,0 +1 @@ +Apply the `federation_ip_range_blacklist` to push and key revocation requests.
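As a quick illustration of what the new setting produces at runtime, the following sketch builds a blacklist with netaddr (already a Synapse dependency), using the first few CIDR ranges from the sample config below; the assertions are illustrative only.

from netaddr import IPAddress, IPSet

# Parse the configured CIDR ranges, then always add the unroutable
# addresses, as the config code in this patch does.
ip_range_blacklist = IPSet(
    ["127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"]
)
ip_range_blacklist.update(["0.0.0.0", "::"])

assert IPAddress("10.1.2.3") in ip_range_blacklist
assert IPAddress("0.0.0.0") in ip_range_blacklist
assert IPAddress("1.1.1.1") not in ip_range_blacklist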
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 394eb9a3ff..6dbccf5932 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -642,17 +642,19 @@ acme: # - nyc.example.com # - syd.example.com -# Prevent federation requests from being sent to the following -# blacklist IP address CIDR ranges. If this option is not specified, or -# specified with an empty list, no ip range blacklist will be enforced. +# Prevent outgoing requests from being sent to the following blacklisted IP address +# CIDR ranges. If this option is not specified, or specified with an empty list, +# no IP range blacklist will be enforced. # -# As of Synapse v1.4.0 this option also affects any outbound requests to identity -# servers provided by user input. +# The blacklist applies to the outbound requests for federation, identity servers, +# push servers, and for checking key validitity for third-party invite events. # # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly # listed here, since they correspond to unroutable addresses.) # -federation_ip_range_blacklist: +# This option replaces federation_ip_range_blacklist in Synapse v1.24.0. +# +ip_range_blacklist: - '127.0.0.0/8' - '10.0.0.0/8' - '172.16.0.0/12' diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1b511890aa..aa12c74358 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -266,7 +266,6 @@ class GenericWorkerPresence(BasePresenceHandler): super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id - self.http_client = hs.get_simple_http_client() self._presence_enabled = hs.config.use_presence diff --git a/synapse/config/federation.py b/synapse/config/federation.py index ffd8fca54e..27ccf61c3c 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -36,22 +36,30 @@ class FederationConfig(Config): for domain in federation_domain_whitelist: self.federation_domain_whitelist[domain] = True - self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [] - ) + ip_range_blacklist = config.get("ip_range_blacklist", []) # Attempt to create an IPSet from the given ranges try: - self.federation_ip_range_blacklist = IPSet( - self.federation_ip_range_blacklist - ) - - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + self.ip_range_blacklist = IPSet(ip_range_blacklist) + except Exception as e: + raise ConfigError("Invalid range(s) provided in ip_range_blacklist: %s" % e) + # Always blacklist 0.0.0.0, :: + self.ip_range_blacklist.update(["0.0.0.0", "::"]) + + # The federation_ip_range_blacklist is used for backwards-compatibility + # and only applies to federation and identity servers. If it is not given, + # default to ip_range_blacklist. + federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", ip_range_blacklist + ) + try: + self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) except Exception as e: raise ConfigError( "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e ) + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) federation_metrics_domains = config.get("federation_metrics_domains") or [] validate_config( @@ -76,17 +84,19 @@ class FederationConfig(Config): # - nyc.example.com # - syd.example.com - # Prevent federation requests from being sent to the following - # blacklist IP address CIDR ranges. 
If this option is not specified, or - # specified with an empty list, no ip range blacklist will be enforced. + # Prevent outgoing requests from being sent to the following blacklisted IP address + # CIDR ranges. If this option is not specified, or specified with an empty list, + # no IP range blacklist will be enforced. # - # As of Synapse v1.4.0 this option also affects any outbound requests to identity - # servers provided by user input. + # The blacklist applies to the outbound requests for federation, identity servers, + # push servers, and for checking key validitity for third-party invite events. # # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly # listed here, since they correspond to unroutable addresses.) # - federation_ip_range_blacklist: + # This option replaces federation_ip_range_blacklist in Synapse v1.24.0. + # + ip_range_blacklist: - '127.0.0.0/8' - '10.0.0.0/8' - '172.16.0.0/12' diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index c04ad77cf9..f23eacc0d7 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -578,7 +578,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): def __init__(self, hs): super().__init__(hs) self.clock = hs.get_clock() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers async def get_keys(self, keys_to_fetch): @@ -748,7 +748,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): def __init__(self, hs): super().__init__(hs) self.clock = hs.get_clock() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() async def get_keys(self, keys_to_fetch): """ diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4b6ab470d0..35e345ce70 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -845,7 +845,6 @@ class FederationHandlerRegistry: def __init__(self, hs: "HomeServer"): self.config = hs.config - self.http_client = hs.get_simple_http_client() self.clock = hs.get_clock() self._instance_name = hs.get_instance_name() diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 17a10f622e..abe9168c78 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -35,7 +35,7 @@ class TransportLayerClient: def __init__(self, hs): self.server_name = hs.hostname - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() @log_function def get_room_state_ids(self, destination, room_id, event_id): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b9799090f7..df82e60b33 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -140,7 +140,7 @@ class FederationHandler(BaseHandler): self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config - self.http_client = hs.get_simple_http_client() + self.http_client = hs.get_proxied_blacklisted_http_client() self._instance_name = hs.get_instance_name() self._replication = hs.get_replication_data_handler() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 9b3c6b4551..7301c24710 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -46,13 +46,13 @@ class IdentityHandler(BaseHandler): def __init__(self, hs): super().__init__(hs) + # An HTTP client for contacting trusted URLs. 
self.http_client = SimpleHttpClient(hs) - # We create a blacklisting instance of SimpleHttpClient for contacting identity - # servers specified by clients + # An HTTP client for contacting identity servers specified by clients. self.blacklisting_http_client = SimpleHttpClient( hs, ip_blacklist=hs.config.federation_ip_range_blacklist ) - self.federation_http_client = hs.get_http_client() + self.federation_http_client = hs.get_federation_http_client() self.hs = hs async def threepid_from_creds( diff --git a/synapse/http/client.py b/synapse/http/client.py index e5b13593f2..df7730078f 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -125,7 +125,7 @@ def _make_scheduler(reactor): return _scheduler -class IPBlacklistingResolver: +class _IPBlacklistingResolver: """ A proxy for reactor.nameResolver which only produces non-blacklisted IP addresses, preventing DNS rebinding attacks on URL preview. @@ -199,6 +199,35 @@ class IPBlacklistingResolver: return r +@implementer(IReactorPluggableNameResolver) +class BlacklistingReactorWrapper: + """ + A Reactor wrapper which will prevent DNS resolution to blacklisted IP + addresses, to prevent DNS rebinding. + """ + + def __init__( + self, + reactor: IReactorPluggableNameResolver, + ip_whitelist: Optional[IPSet], + ip_blacklist: IPSet, + ): + self._reactor = reactor + + # We need to use a DNS resolver which filters out blacklisted IP + # addresses, to prevent DNS rebinding. + self._nameResolver = _IPBlacklistingResolver( + self._reactor, ip_whitelist, ip_blacklist + ) + + def __getattr__(self, attr: str) -> Any: + # Passthrough to the real reactor except for the DNS resolver. + if attr == "nameResolver": + return self._nameResolver + else: + return getattr(self._reactor, attr) + + class BlacklistingAgentWrapper(Agent): """ An Agent wrapper which will prevent access to IP addresses being accessed @@ -292,22 +321,11 @@ class SimpleHttpClient: self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: - real_reactor = hs.get_reactor() # If we have an IP blacklist, we need to use a DNS resolver which # filters out blacklisted IP addresses, to prevent DNS rebinding. 
- nameResolver = IPBlacklistingResolver( - real_reactor, self._ip_whitelist, self._ip_blacklist + self.reactor = BlacklistingReactorWrapper( + hs.get_reactor(), self._ip_whitelist, self._ip_blacklist ) - - @implementer(IReactorPluggableNameResolver) - class Reactor: - def __getattr__(_self, attr): - if attr == "nameResolver": - return nameResolver - else: - return getattr(real_reactor, attr) - - self.reactor = Reactor() else: self.reactor = hs.get_reactor() diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index e77f9587d0..3b756a7dc2 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -16,7 +16,7 @@ import logging import urllib.parse from typing import List, Optional -from netaddr import AddrFormatError, IPAddress +from netaddr import AddrFormatError, IPAddress, IPSet from zope.interface import implementer from twisted.internet import defer @@ -31,6 +31,7 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer from synapse.crypto.context_factory import FederationPolicyForHTTPS +from synapse.http.client import BlacklistingAgentWrapper from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -70,6 +71,7 @@ class MatrixFederationAgent: reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, + ip_blacklist: IPSet, _srv_resolver: Optional[SrvResolver] = None, _well_known_resolver: Optional[WellKnownResolver] = None, ): @@ -90,12 +92,18 @@ class MatrixFederationAgent: self.user_agent = user_agent if _well_known_resolver is None: + # Note that the name resolver has already been wrapped in a + # IPBlacklistingResolver by MatrixFederationHttpClient. 
_well_known_resolver = WellKnownResolver( self._reactor, - agent=Agent( + agent=BlacklistingAgentWrapper( + Agent( + self._reactor, + pool=self._pool, + contextFactory=tls_client_options_factory, + ), self._reactor, - pool=self._pool, - contextFactory=tls_client_options_factory, + ip_blacklist=ip_blacklist, ), user_agent=self.user_agent, ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 4e27f93b7a..c962994727 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -26,11 +26,10 @@ import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json -from zope.interface import implementer from twisted.internet import defer from twisted.internet.error import DNSLookupError -from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime +from twisted.internet.interfaces import IReactorTime from twisted.internet.task import _EPSILON, Cooperator from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer, IResponse @@ -45,7 +44,7 @@ from synapse.api.errors import ( from synapse.http import QuieterFileBodyProducer from synapse.http.client import ( BlacklistingAgentWrapper, - IPBlacklistingResolver, + BlacklistingReactorWrapper, encode_query_args, readBodyToFile, ) @@ -221,31 +220,22 @@ class MatrixFederationHttpClient: self.signing_key = hs.signing_key self.server_name = hs.hostname - real_reactor = hs.get_reactor() - # We need to use a DNS resolver which filters out blacklisted IP # addresses, to prevent DNS rebinding. - nameResolver = IPBlacklistingResolver( - real_reactor, None, hs.config.federation_ip_range_blacklist + self.reactor = BlacklistingReactorWrapper( + hs.get_reactor(), None, hs.config.federation_ip_range_blacklist ) - @implementer(IReactorPluggableNameResolver) - class Reactor: - def __getattr__(_self, attr): - if attr == "nameResolver": - return nameResolver - else: - return getattr(real_reactor, attr) - - self.reactor = Reactor() - user_agent = hs.version_string if hs.config.user_agent_suffix: user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) user_agent = user_agent.encode("ascii") self.agent = MatrixFederationAgent( - self.reactor, tls_client_options_factory, user_agent + self.reactor, + tls_client_options_factory, + user_agent, + hs.config.federation_ip_range_blacklist, ) # Use a BlacklistingAgentWrapper to prevent circumventing the IP diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index eff0975b6a..0e845212a9 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -100,7 +100,7 @@ class HttpPusher: if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") self.url = self.data["url"] - self.http_client = hs.get_proxied_http_client() + self.http_client = hs.get_proxied_blacklisted_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 9cac74ebd8..83beb02b05 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -66,7 +66,7 @@ class MediaRepository: def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() self.clock = hs.get_clock() self.server_name = hs.hostname self.store = 
hs.get_datastore() diff --git a/synapse/server.py b/synapse/server.py index b017e3489f..9af759626e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -350,16 +350,45 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_simple_http_client(self) -> SimpleHttpClient: + """ + An HTTP client with no special configuration. + """ return SimpleHttpClient(self) @cache_in_self def get_proxied_http_client(self) -> SimpleHttpClient: + """ + An HTTP client that uses configured HTTP(S) proxies. + """ + return SimpleHttpClient( + self, + http_proxy=os.getenvb(b"http_proxy"), + https_proxy=os.getenvb(b"HTTPS_PROXY"), + ) + + @cache_in_self + def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient: + """ + An HTTP client that uses configured HTTP(S) proxies and blacklists IPs + based on the IP range blacklist. + """ return SimpleHttpClient( self, + ip_blacklist=self.config.ip_range_blacklist, http_proxy=os.getenvb(b"http_proxy"), https_proxy=os.getenvb(b"HTTPS_PROXY"), ) + @cache_in_self + def get_federation_http_client(self) -> MatrixFederationHttpClient: + """ + An HTTP client for federation. + """ + tls_client_options_factory = context_factory.FederationPolicyForHTTPS( + self.config + ) + return MatrixFederationHttpClient(self, tls_client_options_factory) + @cache_in_self def get_room_creation_handler(self) -> RoomCreationHandler: return RoomCreationHandler(self) @@ -514,13 +543,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_pusherpool(self) -> PusherPool: return PusherPool(self) - @cache_in_self - def get_http_client(self) -> MatrixFederationHttpClient: - tls_client_options_factory = context_factory.FederationPolicyForHTTPS( - self.config - ) - return MatrixFederationHttpClient(self, tls_client_options_factory) - @cache_in_self def get_media_repository_resource(self) -> MediaRepositoryResource: # build the media repo resource. 
This indirects through the HomeServer diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index c98ae75974..bcf1a8010e 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -50,7 +50,9 @@ class FilteringTestCase(unittest.TestCase): self.mock_http_client.put_json = DeferredMockCallable() hs = yield setup_test_homeserver( - self.addCleanup, http_client=self.mock_http_client, keyring=Mock(), + self.addCleanup, + federation_http_client=self.mock_http_client, + keyring=Mock(), ) self.filtering = hs.get_filtering() diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 40abe9d72d..43fef5d64a 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -23,7 +23,7 @@ class FrontendProxyTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserver_to_use=GenericWorkerServer + federation_http_client=None, homeserver_to_use=GenericWorkerServer ) return hs diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index ea3be95cf1..b260ab734d 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserver_to_use=GenericWorkerServer + federation_http_client=None, homeserver_to_use=GenericWorkerServer ) return hs @@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserver_to_use=SynapseHomeServer + federation_http_client=None, homeserver_to_use=SynapseHomeServer ) return hs diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 697916a019..d146f2254f 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() - hs = self.setup_test_homeserver(http_client=self.http_client) + hs = self.setup_test_homeserver(federation_http_client=self.http_client) return hs def test_get_keys_from_server(self): @@ -395,7 +395,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): } ] - return self.setup_test_homeserver(http_client=self.http_client, config=config) + return self.setup_test_homeserver( + federation_http_client=self.http_client, config=config + ) def build_perspectives_response( self, server_name: str, signing_key: SigningKey, valid_until_ts: int, diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 875aaec2c6..5dfeccfeb6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -27,7 +27,7 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() self.store = hs.get_datastore() return hs @@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): def 
make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index ee6ef5e6fa..2f1f2a5517 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -42,7 +42,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.mock_registry.register_query_handler = register_query_handler hs = self.setup_test_homeserver( - http_client=None, + federation_http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, federation_registry=self.mock_registry, diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index bf866dacf3..d0452e1490 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -37,7 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver(http_client=None) + hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastore() return hs diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 8ed67640f8..0794b32c9c 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -463,7 +463,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "server", http_client=None, federation_sender=Mock() + "server", federation_http_client=None, federation_sender=Mock() ) return hs diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index a69fa28b41..ea2bcf2655 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -44,7 +44,7 @@ class ProfileTestCase(unittest.TestCase): hs = yield setup_test_homeserver( self.addCleanup, - http_client=None, + federation_http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, federation_server=Mock(), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index abbdf2d524..36086ca836 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -70,7 +70,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver( notifier=Mock(), - http_client=mock_federation_client, + federation_http_client=mock_federation_client, keyring=mock_keyring, replication_streams={}, ) @@ -192,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - put_json = self.hs.get_http_client().put_json + put_json = self.hs.get_federation_http_client().put_json put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", @@ -270,7 +270,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - put_json = self.hs.get_http_client().put_json + put_json = self.hs.get_federation_http_client().put_json put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 8b5ad4574f..626acdcaa3 100644 --- 
a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -17,6 +17,7 @@ import logging from mock import Mock import treq +from netaddr import IPSet from service_identity import VerificationError from zope.interface import implementer @@ -103,6 +104,7 @@ class MatrixFederationAgentTests(unittest.TestCase): reactor=self.reactor, tls_client_options_factory=self.tls_factory, user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. + ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) @@ -736,6 +738,7 @@ class MatrixFederationAgentTests(unittest.TestCase): reactor=self.reactor, tls_client_options_factory=tls_factory, user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. + ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( self.reactor, diff --git a/tests/push/test_http.py b/tests/push/test_http.py index f118430309..e8cea39c83 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -49,7 +49,9 @@ class HTTPPusherTests(HomeserverTestCase): config = self.default_config() config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, proxied_http_client=m) + hs = self.setup_test_homeserver( + config=config, proxied_blacklisted_http_client=m + ) return hs diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 295c5d58a6..728de28277 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -67,7 +67,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" self.worker_hs = self.setup_test_homeserver( - http_client=None, + federation_http_client=None, homeserver_to_use=GenericWorkerServer, config=self._get_worker_hs_config(), reactor=self.reactor, @@ -264,7 +264,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): worker_app: Type of worker, e.g. `synapse.app.federation_sender`. extra_config: Any extra config to use for this instances. **kwargs: Options that get passed to `self.setup_test_homeserver`, - useful to e.g. pass some mocks for things like `http_client` + useful to e.g. pass some mocks for things like `federation_http_client` Returns: The new worker HomeServer instance. 
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 779745ae9d..fffdb742c8 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.federation_sender", {"send_federation": True}, - http_client=mock_client, + federation_http_client=mock_client, ) user = self.register_user("user", "pass") @@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client1, + federation_http_client=mock_client1, ) mock_client2 = Mock(spec=["put_json"]) @@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client2, + federation_http_client=mock_client2, ) user = self.register_user("user2", "pass") @@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client1, + federation_http_client=mock_client1, ) mock_client2 = Mock(spec=["put_json"]) @@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client2, + federation_http_client=mock_client2, ) user = self.register_user("user3", "pass") diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 67c27a089f..f894bcd6e7 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.pusher", {"start_pushers": True}, - proxied_http_client=http_client_mock, + proxied_blacklisted_http_client=http_client_mock, ) event_id = self._create_pusher_and_send_msg("user") @@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "pusher1", "pusher_instances": ["pusher1", "pusher2"], }, - proxied_http_client=http_client_mock1, + proxied_blacklisted_http_client=http_client_mock1, ) http_client_mock2 = Mock(spec_set=["post_json_get_json"]) @@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "pusher2", "pusher_instances": ["pusher1", "pusher2"], }, - proxied_http_client=http_client_mock2, + proxied_blacklisted_http_client=http_client_mock2, ) # We choose a user name that we know should go to pusher1. 
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 4f76f8f768..67d8878395 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -210,7 +210,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): } config["media_storage_providers"] = [provider_config] - hs = self.setup_test_homeserver(config=config, http_client=client) + hs = self.setup_test_homeserver(config=config, federation_http_client=client) return hs diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 5d5c24d01c..11cd8efe21 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -38,7 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver( "red", - http_client=None, + federation_http_client=None, federation_client=Mock(), presence_handler=presence_handler, ) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 383a9eafac..2a3b483eaf 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -63,7 +63,7 @@ class MockHandlerProfileTestCase(unittest.TestCase): hs = yield setup_test_homeserver( self.addCleanup, "test", - http_client=None, + federation_http_client=None, resource_for_client=self.mock_resource, federation=Mock(), federation_client=Mock(), diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 49f1073c88..e67de41c18 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -45,7 +45,7 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock(), + "red", federation_http_client=None, federation_client=Mock(), ) self.hs.get_federation_handler = Mock() diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index bbd30f594b..ae0207366b 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -39,7 +39,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock(), + "red", federation_http_client=None, federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index fbcf8d5b86..5e90d656f7 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -39,7 +39,7 @@ from tests.utils import default_config class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() - return self.setup_test_homeserver(http_client=self.http_client) + return self.setup_test_homeserver(federation_http_client=self.http_client) def create_test_resource(self): return create_resource_tree( @@ -172,7 +172,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): } ] self.hs2 = self.setup_test_homeserver( - http_client=self.http_client2, config=config + federation_http_client=self.http_client2, config=config ) # wire up outbound POST /key/v2/query requests from hs2 so that they diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2a3b2a8f27..4c749f1a61 100644 --- 
a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -214,7 +214,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): } config["media_storage_providers"] = [provider_config] - hs = self.setup_test_homeserver(config=config, http_client=client) + hs = self.setup_test_homeserver(config=config, federation_http_client=client) return hs diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 35dafbb904..3d7760d5d9 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -26,7 +26,7 @@ room_key = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) self.store = hs.get_datastore() return hs diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index cc1f3c53c5..a06ad2c03e 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase): servlets = [room.register_servlets] def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) return hs def prepare(self, reactor, clock, hs): diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index d4f9e809db..fd0add5db3 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -34,7 +34,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): config = self.default_config() config["redaction_retention_period"] = "30d" return self.setup_test_homeserver( - resource_for_federation=Mock(), http_client=None, config=config + resource_for_federation=Mock(), federation_http_client=None, config=config ) def prepare(self, reactor, clock, hs): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index ff972daeaa..5ba1db2332 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -36,7 +36,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + resource_for_federation=Mock(), federation_http_client=None ) return hs diff --git a/tests/test_federation.py b/tests/test_federation.py index 1ce4ea3a01..fa45f8b3b7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -37,7 +37,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( self.addCleanup, - http_client=self.http_client, + federation_http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor, ) diff --git a/tests/test_server.py b/tests/test_server.py index c387a85f2e..6b2d2f0401 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -38,7 +38,10 @@ class JsonResourceTests(unittest.TestCase): self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, + federation_http_client=None, + clock=self.hs_clock, + reactor=self.reactor, ) def test_handler_for_request(self): -- cgit 1.5.1 From 7ea85302f3ce59cdb38bceb1c2aeea2d690e2dbb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 Dec 2020 
15:26:25 +0000 Subject: fix up various test cases A few test cases were relying on being able to mount non-client servlets on the test resource. it's better to give them their own Resources. --- synapse/federation/transport/server.py | 2 +- tests/handlers/test_typing.py | 11 ++++++++--- tests/replication/_base.py | 14 ++++++++------ tests/server.py | 3 ++- tests/unittest.py | 25 +++++++++++++++++++------ 5 files changed, 38 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index b53e7a20ec..434718ddfc 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N Args: hs (synapse.server.HomeServer): homeserver - resource (TransportLayerServer): resource class to register to + resource (JsonResource): resource class to register to authenticator (Authenticator): authenticator to use ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use servlet_groups (list[str], optional): List of servlet groups to register. diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index abbdf2d524..9a085ccaf6 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -15,18 +15,20 @@ import json +from typing import Dict from mock import ANY, Mock, call from twisted.internet import defer +from twisted.web.resource import Resource from synapse.api.errors import AuthError +from synapse.federation.transport.server import TransportLayerServer from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import make_awaitable from tests.unittest import override_config -from tests.utils import register_federation_servlets # Some local users to test with U_APPLE = UserID.from_string("@apple:test") @@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content): class TypingNotificationsTestCase(unittest.HomeserverTestCase): - servlets = [register_federation_servlets] - def make_homeserver(self, reactor, clock): # we mock out the keyring so as to skip the authentication check on the # federation API call. @@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): return hs + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_matrix/federation"] = TransportLayerServer(self.hs) + return d + def prepare(self, reactor, clock, hs): mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 295c5d58a6..79738ab46f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import attr @@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel +from twisted.web.resource import Resource from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, @@ -28,7 +29,7 @@ from synapse.app.generic_worker import ( ) from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite -from synapse.replication.http import ReplicationRestResource, streams +from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): if not hiredis: skip = "Requires hiredis" - servlets = [ - streams.register_servlets, - ] - def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) @@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/replication"] = ReplicationRestResource(self.hs) + return d + def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.generic_worker" diff --git a/tests/server.py b/tests/server.py index a51ad0c14e..eee970c43c 100644 --- a/tests/server.py +++ b/tests/server.py @@ -216,8 +216,9 @@ def make_request( and not path.startswith(b"/_matrix") and not path.startswith(b"/_synapse") ): + if path.startswith(b"/"): + path = path[1:] path = b"/_matrix/client/r0/" + path - path = path.replace(b"//", b"/") if not path.startswith(b"/"): path = b"/" + path diff --git a/tests/unittest.py b/tests/unittest.py index 425b39b1d1..102b0a1f34 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -705,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase): A federating homeserver that authenticates incoming requests as `other.example.com`. """ - def prepare(self, reactor, clock, homeserver): + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_matrix/federation"] = TestTransportLayerServer(self.hs) + return d + + +class TestTransportLayerServer(JsonResource): + """A test implementation of TransportLayerServer + + authenticates incoming requests as `other.example.com`. 
+ """ + + def __init__(self, hs): + super().__init__(hs) + class Authenticator: def authenticate_request(self, request, content): return succeed("other.example.com") + authenticator = Authenticator() + ratelimiter = FederationRateLimiter( - clock, + hs.get_clock(), FederationRateLimitConfig( window_size=1, sleep_limit=1, @@ -720,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase): concurrent_requests=1000, ), ) - federation_server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - return super().prepare(reactor, clock, homeserver) + federation_server.register_servlets(hs, self, authenticator, ratelimiter) def override_config(extra_config): -- cgit 1.5.1 From 90cf1eec44940f8ede4a7f0490da43021c84a13a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 Dec 2020 15:12:02 +0000 Subject: Remove redundant mocking --- tests/api/test_filtering.py | 18 ++---------------- tests/handlers/test_directory.py | 2 -- tests/handlers/test_profile.py | 2 -- tests/storage/test_redaction.py | 11 +++-------- tests/storage/test_roommember.py | 8 -------- 5 files changed, 5 insertions(+), 36 deletions(-) (limited to 'tests') diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index bcf1a8010e..279c94a03d 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -16,8 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock - import jsonschema from twisted.internet import defer @@ -28,7 +26,7 @@ from synapse.api.filtering import Filter from synapse.events import make_event_from_dict from tests import unittest -from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver +from tests.utils import setup_test_homeserver user_localpart = "test_user" @@ -42,21 +40,9 @@ def MockEvent(**kwargs): class FilteringTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): - self.mock_federation_resource = MockHttpResource() - - self.mock_http_client = Mock(spec=[]) - self.mock_http_client.put_json = DeferredMockCallable() - - hs = yield setup_test_homeserver( - self.addCleanup, - federation_http_client=self.mock_http_client, - keyring=Mock(), - ) - + hs = setup_test_homeserver(self.addCleanup) self.filtering = hs.get_filtering() - self.datastore = hs.get_datastore() def test_errors_on_invalid_filters(self): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 2f1f2a5517..770d225ed5 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -42,8 +42,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.mock_registry.register_query_handler = register_query_handler hs = self.setup_test_homeserver( - federation_http_client=None, - resource_for_federation=Mock(), federation_client=self.mock_federation, federation_registry=self.mock_registry, ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index ea2bcf2655..919547556b 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase): hs = yield setup_test_homeserver( self.addCleanup, - federation_http_client=None, - resource_for_federation=Mock(), federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index fd0add5db3..a6303bf0ee 100644 --- 
a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -14,9 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from mock import Mock - from canonicaljson import json from twisted.internet import defer @@ -30,12 +27,10 @@ from tests.utils import create_room class RedactionTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - config = self.default_config() + def default_config(self): + config = super().default_config() config["redaction_retention_period"] = "30d" - return self.setup_test_homeserver( - resource_for_federation=Mock(), federation_http_client=None, config=config - ) + return config def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5ba1db2332..d2aed66f6d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock - from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room @@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver( - resource_for_federation=Mock(), federation_http_client=None - ) - return hs - def prepare(self, reactor, clock, hs: TestHomeServer): # We can't test the RoomMemberStore on its own without the other event -- cgit 1.5.1 From 76469898ee797db232adaccb9fd547bddab2fe59 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 Dec 2020 18:22:01 +0000 Subject: Factor out FakeResponse from test_oidc --- tests/handlers/test_oidc.py | 17 +---------------- tests/test_utils/__init__.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index d485af52fd..23ea1e3543 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse from mock import Mock, patch -import attr import pymacaroons -from twisted.python.failure import Failure -from twisted.web._newclient import ResponseDone - from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider from synapse.handlers.sso import MappingException from synapse.types import UserID +from tests.test_utils import FakeResponse from tests.unittest import HomeserverTestCase, override_config - -@attr.s -class FakeResponse: - code = attr.ib() - body = attr.ib() - phrase = attr.ib() - - def deliverBody(self, protocol): - protocol.dataReceived(self.body) - protocol.connectionLost(Failure(ResponseDone())) - - # These are a few constants that are used as config parameters in the tests. 
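The `FakeResponse` being factored out here (its new home is the `tests/test_utils` hunk below) is just enough of a twisted `IResponse` for body-reading helpers: `deliverBody` feeds the stored body to the protocol, then signals a clean `ResponseDone`. A small sketch of how such a stub can be consumed, assuming the attrs-based definition from the hunk below:

from twisted.web.client import readBody

from tests.test_utils import FakeResponse

# readBody() only touches .code, .phrase and .deliverBody(), which is exactly
# what FakeResponse provides (and what treq's equivalent lacked).
resp = FakeResponse(code=200, phrase=b"OK", body=b'{"sub": "alice"}')
d = readBody(resp)
# deliverBody() runs synchronously here, so the deferred has already fired
# with b'{"sub": "alice"}'.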
ISSUER = "https://issuer/" CLIENT_ID = "test-client-id" diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index d232b72264..6873d45eb6 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -22,6 +22,11 @@ import warnings from asyncio import Future from typing import Any, Awaitable, Callable, TypeVar +import attr + +from twisted.python.failure import Failure +from twisted.web.client import ResponseDone + TV = TypeVar("TV") @@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]: sys.unraisablehook = unraisablehook # type: ignore return cleanup + + +@attr.s +class FakeResponse: + """A fake twisted.web.IResponse object + + there is a similar class at treq.test.test_response, but it lacks a `phrase` + attribute, and didn't support deliverBody until recently. + """ + + # HTTP response code + code = attr.ib(type=int) + + # HTTP response phrase (eg b'OK' for a 200) + phrase = attr.ib(type=bytes) + + # body of the response + body = attr.ib(type=bytes) + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) -- cgit 1.5.1 From c834f1d67a7feeaebc353d0170f99a618bf32b5b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 Dec 2020 17:40:31 +0000 Subject: remove unused `resource_for_federation` This is now only used in `test_typing`, so move it there. --- tests/handlers/test_typing.py | 14 +++++++++++++- tests/utils.py | 17 ----------------- 2 files changed, 13 insertions(+), 18 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 36086ca836..294b71e8e2 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -21,12 +21,13 @@ from mock import ANY, Mock, call from twisted.internet import defer from synapse.api.errors import AuthError +from synapse.federation.transport import server as federation_server from synapse.types import UserID, create_requester +from synapse.util.ratelimitutils import FederationRateLimiter from tests import unittest from tests.test_utils import make_awaitable from tests.unittest import override_config -from tests.utils import register_federation_servlets # Some local users to test with U_APPLE = UserID.from_string("@apple:test") @@ -52,6 +53,17 @@ def _make_edu_transaction_json(edu_type, content): return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") +def register_federation_servlets(hs, resource): + federation_server.register_servlets( + hs, + resource=resource, + authenticator=federation_server.Authenticator(hs), + ratelimiter=FederationRateLimiter( + hs.get_clock(), config=hs.config.rc_federation + ), + ) + + class TypingNotificationsTestCase(unittest.HomeserverTestCase): servlets = [register_federation_servlets] diff --git a/tests/utils.py b/tests/utils.py index c8d3ffbaba..f5f56e5a79 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -34,7 +34,6 @@ from synapse.api.room_versions import RoomVersions from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer @@ -42,7 +41,6 @@ from synapse.storage import DataStore from synapse.storage.database import LoggingDatabaseConnection from 
synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database -from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. # @@ -342,24 +340,9 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash - fed = kwargs.get("resource_for_federation", None) - if fed: - register_federation_servlets(hs, fed) - return hs -def register_federation_servlets(hs, resource): - federation_server.register_servlets( - hs, - resource=resource, - authenticator=federation_server.Authenticator(hs), - ratelimiter=FederationRateLimiter( - hs.get_clock(), config=hs.config.rc_federation - ), - ) - - def get_mock_call_args(pattern_func, mock_func): """ Return the arguments the mock function was called with interpreted by the pattern functions argument list. -- cgit 1.5.1 From b751624ff8cd110ebfd312f13b330c9ac7cd0d86 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 Dec 2020 15:13:05 +0000 Subject: remove unused DeferredMockCallable --- tests/utils.py | 91 +--------------------------------------------------------- 1 file changed, 1 insertion(+), 90 deletions(-) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index f5f56e5a79..977eeaf6ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,13 +20,12 @@ import os import time import uuid import warnings -from inspect import getcallargs from typing import Type from urllib import parse as urlparse from mock import Mock, patch -from twisted.internet import defer, reactor +from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error @@ -343,14 +342,6 @@ def setup_test_homeserver( return hs -def get_mock_call_args(pattern_func, mock_func): - """ Return the arguments the mock function was called with interpreted - by the pattern functions argument list. - """ - invoked_args, invoked_kargs = mock_func.call_args - return getcallargs(pattern_func, *invoked_args, **invoked_kargs) - - def mock_getRawHeaders(headers=None): headers = headers if headers is not None else {} @@ -536,86 +527,6 @@ class MockClock: return d -def _format_call(args, kwargs): - return ", ".join( - ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()] - ) - - -class DeferredMockCallable: - """A callable instance that stores a set of pending call expectations and - return values for them. It allows a unit test to assert that the given set - of function calls are eventually made, by awaiting on them to be called. 
- """ - - def __init__(self): - self.expectations = [] - self.calls = [] - - def __call__(self, *args, **kwargs): - self.calls.append((args, kwargs)) - - if not self.expectations: - raise ValueError( - "%r has no pending calls to handle call(%s)" - % (self, _format_call(args, kwargs)) - ) - - for (call, result, d) in self.expectations: - if args == call[1] and kwargs == call[2]: - d.callback(None) - return result - - failure = AssertionError( - "Was not expecting call(%s)" % (_format_call(args, kwargs)) - ) - - for _, _, d in self.expectations: - try: - d.errback(failure) - except Exception: - pass - - raise failure - - def expect_call_and_return(self, call, result): - self.expectations.append((call, result, defer.Deferred())) - - @defer.inlineCallbacks - def await_calls(self, timeout=1000): - deferred = defer.DeferredList( - [d for _, _, d in self.expectations], fireOnOneErrback=True - ) - - timer = reactor.callLater( - timeout / 1000, - deferred.errback, - AssertionError( - "%d pending calls left: %s" - % ( - len([e for e in self.expectations if not e[2].called]), - [e for e in self.expectations if not e[2].called], - ) - ), - ) - - yield deferred - - timer.cancel() - - self.calls = [] - - def assert_had_no_calls(self): - if self.calls: - calls = self.calls - self.calls = [] - - raise AssertionError( - "Expected not to received any calls, got:\n" - + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls]) - ) - - async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room """ -- cgit 1.5.1 From 0bac276890567ef3a3fafd7f5b7b5cac91a1031b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 1 Dec 2020 00:15:36 +0000 Subject: UIA: offer only available auth flows During user-interactive auth, do not offer password auth to users with no password, nor SSO auth to users with no SSO. Fixes #7559. --- synapse/handlers/auth.py | 58 ++++++++--- synapse/storage/databases/main/registration.py | 25 +++++ .../delta/58/25user_external_ids_user_id_idx.sql | 17 +++ tests/rest/client/v1/utils.py | 116 ++++++++++++++++++++- tests/rest/client/v2_alpha/test_auth.py | 94 ++++++++++++++--- tests/server.py | 1 + 6 files changed, 278 insertions(+), 33 deletions(-) create mode 100644 synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c7dc07008a..2e72298e05 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -193,9 +193,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - self._sso_enabled = ( - hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled - ) + self._password_localdb_enabled = hs.config.password_localdb_enabled # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first @@ -205,7 +203,7 @@ class AuthHandler(BaseHandler): # start out by assuming PASSWORD is enabled; we will remove it later if not. 
login_types = [] - if hs.config.password_localdb_enabled: + if self._password_localdb_enabled: login_types.append(LoginType.PASSWORD) for provider in self.password_providers: @@ -219,14 +217,6 @@ class AuthHandler(BaseHandler): self._supported_login_types = login_types - # Login types and UI Auth types have a heavy overlap, but are not - # necessarily identical. Login types have SSO (and other login types) - # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. - ui_auth_types = login_types.copy() - if self._sso_enabled: - ui_auth_types.append(LoginType.SSO) - self._supported_ui_auth_types = ui_auth_types - # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( @@ -339,7 +329,10 @@ class AuthHandler(BaseHandler): self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) # build a list of supported flows - flows = [[login_type] for login_type in self._supported_ui_auth_types] + supported_ui_auth_types = await self._get_available_ui_auth_types( + requester.user + ) + flows = [[login_type] for login_type in supported_ui_auth_types] try: result, params, session_id = await self.check_ui_auth( @@ -351,7 +344,7 @@ class AuthHandler(BaseHandler): raise # find the completed login type - for login_type in self._supported_ui_auth_types: + for login_type in supported_ui_auth_types: if login_type not in result: continue @@ -367,6 +360,41 @@ class AuthHandler(BaseHandler): return params, session_id + async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: + """Get a list of the authentication types this user can use + """ + + ui_auth_types = set() + + # if the HS supports password auth, and the user has a non-null password, we + # support password auth + if self._password_localdb_enabled and self._password_enabled: + lookupres = await self._find_user_id_and_pwd_hash(user.to_string()) + if lookupres: + _, password_hash = lookupres + if password_hash: + ui_auth_types.add(LoginType.PASSWORD) + + # also allow auth from password providers + for provider in self.password_providers: + for t in provider.get_supported_login_types().keys(): + if t == LoginType.PASSWORD and not self._password_enabled: + continue + ui_auth_types.add(t) + + # if sso is enabled, allow the user to log in via SSO iff they have a mapping + # from sso to mxid. + if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled: + if await self.store.get_external_ids_by_user(user.to_string()): + ui_auth_types.add(LoginType.SSO) + + # Our CAS impl does not (yet) correctly register users in user_external_ids, + # so always offer that if it's available. 
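# For illustration (not part of the patch): each entry in the set returned
# above becomes a single-stage UIA flow, so a user with both a password hash
# and an SSO mapping is offered, in some order,
#     [{"stages": ["m.login.password"]}, {"stages": ["m.login.sso"]}]
# which is exactly what the new UIA tests below assert.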
+ if self.hs.config.cas.cas_enabled: + ui_auth_types.add(LoginType.SSO) + + return ui_auth_types + def get_enabled_auth_types(self): """Return the enabled user-interactive authentication types @@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler): if result: return result - if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: + if login_type == LoginType.PASSWORD and self._password_localdb_enabled: known_login_type = True # we've already checked that there is a (valid) password field diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index fedb8a6c26..ff96c34c2e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_external_id", ) + async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]: + """Look up external ids for the given user + + Args: + mxid: the MXID to be looked up + + Returns: + Tuples of (auth_provider, external_id) + """ + res = await self.db_pool.simple_select_list( + table="user_external_ids", + keyvalues={"user_id": mxid}, + retcols=("auth_provider", "external_id"), + desc="get_external_ids_by_user", + ) + return [(r["auth_provider"], r["external_id"]) for r in res] + async def count_all_users(self): """Counts all users registered on the homeserver.""" @@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) + self.db_pool.updates.register_background_index_update( + "user_external_ids_user_id_idx", + index_name="user_external_ids_user_id_idx", + table="user_external_ids", + columns=["user_id"], + unique=False, + ) + async def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql new file mode 100644 index 0000000000..8f5e65aa71 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5825, 'user_external_ids_user_id_idx', '{}'); diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 737c38c396..5a18af8d34 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -2,7 +2,7 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,17 +17,23 @@ # limitations under the License. import json +import re import time +import urllib.parse from typing import Any, Dict, Optional +from mock import patch + import attr from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.types import JsonDict from tests.server import FakeSite, make_request +from tests.test_utils import FakeResponse @attr.s @@ -344,3 +350,111 @@ class RestHelper: ) return channel.json_body + + def login_via_oidc(self, remote_user_id: str) -> JsonDict: + """Log in (as a new user) via OIDC + + Returns the result of the final token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_baseurl". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + """ + client_redirect_url = "https://x" + + # first hit the redirect url (which will issue a cookie and state) + _, channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "/login/sso/redirect?redirectUrl=" + client_redirect_url, + ) + # that will redirect to the OIDC IdP, but we skip that and go straight + # back to synapse's OIDC callback resource. However, we do need the "state" + # param that synapse passes to the IdP via query params, and the cookie that + # synapse passes to the client. + assert channel.code == 302 + oauth_uri = channel.headers.getRawHeaders("Location")[0] + params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query) + redirect_uri = "%s?%s" % ( + urllib.parse.urlparse(params["redirect_uri"][0]).path, + urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + ) + cookies = {} + for h in channel.headers.getRawHeaders("Set-Cookie"): + parts = h.split(";") + k, v = parts[0].split("=", maxsplit=1) + cookies[k] = v + + # before we hit the callback uri, stub out some methods in the http client so + # that we don't have to handle full HTTPS requests. + + # (expected url, json response) pairs, in the order we expect them. + expected_requests = [ + # first we get a hit to the token endpoint, which we tell to return + # a dummy OIDC access token + ("https://issuer.test/token", {"access_token": "TEST"}), + # and then one to the user_info endpoint, which returns our remote user id.
+ ("https://issuer.test/userinfo", {"sub": remote_user_id}), + ] + + async def mock_req(method: str, uri: str, data=None, headers=None): + (expected_uri, resp_obj) = expected_requests.pop(0) + assert uri == expected_uri + resp = FakeResponse( + code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), + ) + return resp + + with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + # now hit the callback URI with the right params and a made-up code + _, channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + redirect_uri, + custom_headers=[ + ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() + ], + ) + + # expect a confirmation page + assert channel.code == 200 + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.result["body"].decode("utf-8"), + ) + assert m + login_token = m.group(1) + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token and device id. + _, channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + assert channel.code == 200 + return channel.json_body + + +# an 'oidc_config' suitable for login_with_oidc. +TEST_OIDC_CONFIG = { + "enabled": True, + "discover": False, + "issuer": "https://issuer.test", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://z", + "token_endpoint": "https://issuer.test/token", + "userinfo_endpoint": "https://issuer.test/userinfo", + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 77246e478f..ac67a9de29 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from typing import List, Union from twisted.internet.defer import succeed @@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.http.site import SynapseRequest from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register -from synapse.types import JsonDict +from synapse.rest.oidc import OIDCResource +from synapse.types import JsonDict, UserID from tests import unittest +from tests.rest.client.v1.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel @@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase): register.register_servlets, ] + def default_config(self): + config = super().default_config() + + # we enable OIDC as a way of testing SSO flows + oidc_config = {} + oidc_config.update(TEST_OIDC_CONFIG) + oidc_config["allow_existing_users"] = True + + config["oidc_config"] = oidc_config + config["public_baseurl"] = "https://synapse.test" + return config + + def create_resource_dict(self): + resource_dict = super().create_resource_dict() + # mount the OIDC resource at /_synapse/oidc + resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) + return resource_dict + def prepare(self, reactor, clock, hs): self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.user_tok = self.login("test", self.user_pass) - def get_device_ids(self) -> List[str]: + def get_device_ids(self, access_token: str) -> List[str]: # Get the list of devices so one can be deleted. - request, channel = self.make_request( - "GET", "devices", access_token=self.user_tok, - ) # type: SynapseRequest, FakeChannel - - # Get the ID of the device. - self.assertEqual(request.code, 200) + _, channel = self.make_request("GET", "devices", access_token=access_token,) + self.assertEqual(channel.code, 200) return [d["device_id"] for d in channel.json_body["devices"]] def delete_device( - self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" + self, + access_token: str, + device: str, + expected_response: int, + body: Union[bytes, JsonDict] = b"", ) -> FakeChannel: """Delete an individual device.""" request, channel = self.make_request( - "DELETE", "devices/" + device, body, access_token=self.user_tok + "DELETE", "devices/" + device, body, access_token=access_token, ) # type: SynapseRequest, FakeChannel # Ensure the response is sane. @@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase): """ Test user interactive authentication outside of registration. """ - device_id = self.get_device_ids()[0] + device_id = self.get_device_ids(self.user_tok)[0] # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(device_id, 401) + channel = self.delete_device(self.user_tok, device_id, 401) # Grab the session session = channel.json_body["session"] @@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow. self.delete_device( + self.user_tok, device_id, 200, { @@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - device_id = self.get_device_ids()[0] - channel = self.delete_device(device_id, 401) + device_id = self.get_device_ids(self.user_tok)[0] + channel = self.delete_device(self.user_tok, device_id, 401) session = channel.json_body["session"] # Make another request providing the UI auth flow. 
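# For illustration (the patch leaves this body unchanged): the dict elided
# below is the standard UIA password payload, shaped like
#     {"auth": {"type": "m.login.password",
#               "identifier": {"type": "m.id.user", "user": self.user},
#               "password": self.user_pass,
#               "session": session}}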
self.delete_device( + self.user_tok, device_id, 200, { @@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Create a second login. self.login("test", self.user_pass) - device_ids = self.get_device_ids() + device_ids = self.get_device_ids(self.user_tok) self.assertEqual(len(device_ids), 2) # Attempt to delete the first device. @@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase): # Create a second login. self.login("test", self.user_pass) - device_ids = self.get_device_ids() + device_ids = self.get_device_ids(self.user_tok) self.assertEqual(len(device_ids), 2) # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(device_ids[0], 401) + channel = self.delete_device(self.user_tok, device_ids[0], 401) # Grab the session session = channel.json_body["session"] @@ -313,6 +336,7 @@ # Make another request providing the UI auth flow, but try to delete the # second device. This results in an error. self.delete_device( + self.user_tok, device_ids[1], 403, { @@ -324,3 +348,39 @@ }, }, ) + + def test_does_not_offer_password_for_sso_user(self): + login_resp = self.helper.login_via_oidc("username") + user_tok = login_resp["access_token"] + device_id = login_resp["device_id"] + + # now call the device deletion API: we should get the option to auth with SSO + # and not password. + channel = self.delete_device(user_tok, device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) + + def test_does_not_offer_sso_for_password_user(self): + # now call the device deletion API: we should get the option to auth with + # password and not SSO. + device_ids = self.get_device_ids(self.user_tok) + channel = self.delete_device(self.user_tok, device_ids[0], 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.password"]}]) + + def test_offers_both_flows_for_upgraded_user(self): + """A user that had a password and then logged in with SSO should get both flows """ + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + device_ids = self.get_device_ids(self.user_tok) + channel = self.delete_device(self.user_tok, device_ids[0], 401) + + flows = channel.json_body["flows"] + # we have no particular expectations of ordering here + self.assertIn({"stages": ["m.login.password"]}, flows) + self.assertIn({"stages": ["m.login.sso"]}, flows) + self.assertEqual(len(flows), 2) diff --git a/tests/server.py b/tests/server.py index eee970c43c..4faf32e335 100644 --- a/tests/server.py +++ b/tests/server.py @@ -259,6 +259,7 @@ def make_request( for k, v in custom_headers: req.requestHeaders.addRawHeader(k, v) + req.parseCookies() req.requestReceived(method, path, b"1.1") if await_result: -- cgit 1.5.1 From f347f0cd581bb2dd8322a4e97c75d6b25d79db8a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 2 Dec 2020 18:58:25 +0000 Subject: remove unused FakeResponse (#8864) --- changelog.d/8864.misc | 1 + tests/rest/media/v1/test_url_preview.py | 26 -------------------------- 2 files changed, 1 insertion(+), 26 deletions(-) create mode 100644 changelog.d/8864.misc (limited to 'tests') diff --git a/changelog.d/8864.misc b/changelog.d/8864.misc new file mode 100644 index 0000000000..a780883495 --- /dev/null +++ b/changelog.d/8864.misc @@ -0,0 +1 @@
+Remove unused `FakeResponse` class from unit tests. diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index ccdc8c2ecf..529b6bcded 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -18,41 +18,15 @@ import re from mock import patch -import attr - from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError -from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol -from twisted.web._newclient import ResponseDone from tests import unittest from tests.server import FakeTransport -@attr.s -class FakeResponse: - version = attr.ib() - code = attr.ib() - phrase = attr.ib() - headers = attr.ib() - body = attr.ib() - absoluteURI = attr.ib() - - @property - def request(self): - @attr.s - class FakeTransport: - absoluteURI = self.absoluteURI - - return FakeTransport() - - def deliverBody(self, protocol): - protocol.dataReceived(self.body) - protocol.connectionLost(Failure(ResponseDone())) - - class URLPreviewTests(unittest.HomeserverTestCase): hijack_auth = True -- cgit 1.5.1 From cf3b8156bec7d137894ebc3f6998c3c81a3854aa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 3 Dec 2020 15:41:19 +0000 Subject: Fix errorcode for disabled registration (#8867) The spec says we should return `M_FORBIDDEN` when someone tries to register and registration is disabled. --- changelog.d/8867.bugfix | 1 + synapse/rest/client/v2_alpha/register.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8867.bugfix (limited to 'tests') diff --git a/changelog.d/8867.bugfix b/changelog.d/8867.bugfix new file mode 100644 index 0000000000..f2414ff111 --- /dev/null +++ b/changelog.d/8867.bugfix @@ -0,0 +1 @@ +Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled. diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a89ae6ddf9..9041e7ed76 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -451,7 +451,7 @@ class RegisterRestServlet(RestServlet): # == Normal User Registration == (everyone else) if not self._registration_enabled: - raise SynapseError(403, "Registration has been disabled") + raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) # For regular registration, convert the provided username to lowercase # before attempting to register it. 
This should mean that people who try diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 8f0c2430e8..bcb21d0ced 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -121,6 +121,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Registration has been disabled") + self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self): self.hs.config.macaroon_secret_key = "test" -- cgit 1.5.1 From b774c555d821170e4f16de7d48f01484c3a1d740 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 4 Dec 2020 10:51:56 -0500 Subject: Add additional validation to pusher URLs. (#8865) Pusher URLs now must end in `/_matrix/push/v1/notify` per the specification. --- changelog.d/8865.bugfix | 1 + synapse/push/__init__.py | 3 +- synapse/push/httppusher.py | 16 ++++- tests/push/test_http.py | 106 ++++++++++++++++++++++++++------- tests/replication/test_pusher_shard.py | 8 +-- tests/rest/admin/test_user.py | 2 +- 6 files changed, 106 insertions(+), 30 deletions(-) create mode 100644 changelog.d/8865.bugfix (limited to 'tests') diff --git a/changelog.d/8865.bugfix b/changelog.d/8865.bugfix new file mode 100644 index 0000000000..a1e625f552 --- /dev/null +++ b/changelog.d/8865.bugfix @@ -0,0 +1 @@ +Add additional validation to pusher URLs to be compliant with the specification. diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 5a437f9810..e462fb2e13 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -15,5 +15,4 @@ class PusherConfigException(Exception): - def __init__(self, msg): - super().__init__(msg) + """An error occurred when creating a pusher.""" diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 0e845212a9..6a0ee8274c 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import urllib.parse from prometheus_client import Counter @@ -97,9 +98,22 @@ class HttpPusher: if self.data is None: raise PusherConfigException("data can not be null for HTTP pusher") + # Validate that there's a URL and it is of the proper form. if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") - self.url = self.data["url"] + + url = self.data["url"] + if not isinstance(url, str): + raise PusherConfigException("'url' must be a string") + url_parts = urllib.parse.urlparse(url) + # Note that the specification also says the scheme must be HTTPS, but + # it isn't up to the homeserver to verify that. 
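# For reference (illustrative, not part of the patch), urllib.parse.urlparse
# gives the following for the URL shapes exercised by the new tests below:
#     urlparse("https://push.example.com/_matrix/push/v1/notify").path
#         == "/_matrix/push/v1/notify"                      -> accepted
#     urlparse("example.com").path == "example.com"         -> rejected
#     urlparse("http://example.com").path == ""             -> rejected
#     urlparse("http://example.com/foo").path == "/foo"     -> rejected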
+ if url_parts.path != "/_matrix/push/v1/notify": + raise PusherConfigException( + "'url' must have a path of '/_matrix/push/v1/notify'" + ) + + self.url = url self.http_client = hs.get_proxied_blacklisted_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index e8cea39c83..8b4af74c51 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -18,6 +18,7 @@ from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable +from synapse.push import PusherConfigException from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import receipts @@ -34,6 +35,11 @@ class HTTPPusherTests(HomeserverTestCase): user_id = True hijack_auth = False + def default_config(self): + config = super().default_config() + config["start_pushers"] = True + return config + def make_homeserver(self, reactor, clock): self.push_attempts = [] @@ -46,14 +52,48 @@ class HTTPPusherTests(HomeserverTestCase): m.post_json_get_json = post_json_get_json - config = self.default_config() - config["start_pushers"] = True + hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m) + + return hs - hs = self.setup_test_homeserver( - config=config, proxied_blacklisted_http_client=m + def test_invalid_configuration(self): + """Invalid push configurations should be rejected.""" + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) ) + token_id = user_tuple.token_id - return hs + def test_data(data): + self.get_failure( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data=data, + ), + PusherConfigException, + ) + + # Data must be provided with a URL. + test_data(None) + test_data({}) + test_data({"url": 1}) + # A bare domain name isn't accepted. + test_data({"url": "example.com"}) + # A URL without a path isn't accepted. + test_data({"url": "http://example.com"}) + # A url with an incorrect path isn't accepted. + test_data({"url": "http://example.com/foo"}) def test_sends_http(self): """ @@ -84,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -119,7 +159,9 @@ class HTTPPusherTests(HomeserverTestCase): # One push was attempted to be sent -- it'll be the first message self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual( self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!" ) @@ -139,7 +181,9 @@ class HTTPPusherTests(HomeserverTestCase): # Now it'll try and send the second push message, which will be the second one self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual( self.push_attempts[1][2]["notification"]["content"]["body"], "There!" 
) @@ -196,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -232,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase): # Check our push made it with high priority self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") # Add yet another person — we want to make this room not a 1:1 @@ -270,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase): # Advance time a bit, so the pusher will register something has happened self.pump() self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") def test_sends_high_priority_for_one_to_one_only(self): @@ -312,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -328,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase): # Check our push made it with high priority — this is a one-to-one room self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") # Yet another user joins @@ -347,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase): # Advance time a bit, so the pusher will register something has happened self.pump() self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") @@ -394,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -410,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase): # Check our push made it with high priority self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") # Send another event, this time with no mention @@ -419,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase): # Advance time a bit, so the pusher will register something has happened self.pump() self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") @@ -467,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": 
"example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -487,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase): # Check our push made it with high priority self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") # Send another event, this time as someone without the power of @room @@ -498,7 +556,9 @@ class HTTPPusherTests(HomeserverTestCase): # Advance time a bit, so the pusher will register something has happened self.pump() self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") @@ -572,7 +632,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -591,7 +651,9 @@ class HTTPPusherTests(HomeserverTestCase): # Check our push made it self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) # Check that the unread count for the room is 0 # diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index f894bcd6e7..800ad94a04 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "https://push.example.com/push"}, + data={"url": "https://push.example.com/_matrix/push/v1/notify"}, ) ) @@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): http_client_mock.post_json_get_json.assert_called_once() self.assertEqual( http_client_mock.post_json_get_json.call_args[0][0], - "https://push.example.com/push", + "https://push.example.com/_matrix/push/v1/notify", ) self.assertEqual( event_id, @@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): http_client_mock2.post_json_get_json.assert_not_called() self.assertEqual( http_client_mock1.post_json_get_json.call_args[0][0], - "https://push.example.com/push", + "https://push.example.com/_matrix/push/v1/notify", ) self.assertEqual( event_id, @@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): http_client_mock2.post_json_get_json.assert_called_once() self.assertEqual( http_client_mock2.post_json_get_json.call_args[0][0], - "https://push.example.com/push", + "https://push.example.com/_matrix/push/v1/notify", ) self.assertEqual( event_id, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 54d46f4bd3..35c546aa69 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1256,7 +1256,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "https://example.com/_matrix/push/v1/notify"}, ) ) -- cgit 1.5.1 From df4b1e9c74d56d79c274149b0dfb0fd5305c7659 Mon Sep 17 00:00:00 2001 From: Erik Johnston 
Date: Fri, 4 Dec 2020 15:52:49 +0000 Subject: Pass room_id to get_auth_chain_difference (#8879) This is so that we can choose which algorithm to use based on the room ID. --- changelog.d/8879.misc | 1 + synapse/state/__init__.py | 4 ++-- synapse/state/v2.py | 9 +++++++-- synapse/storage/databases/main/event_federation.py | 4 +++- tests/state/test_v2.py | 14 ++++++++++---- tests/storage/test_event_federation.py | 18 ++++++++++-------- 6 files changed, 33 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8879.misc (limited to 'tests') diff --git a/changelog.d/8879.misc b/changelog.d/8879.misc new file mode 100644 index 0000000000..6f9516b314 --- /dev/null +++ b/changelog.d/8879.misc @@ -0,0 +1 @@ +Pass `room_id` to `get_auth_chain_difference`. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 1fa3b280b4..84f59c7d85 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -783,7 +783,7 @@ class StateResolutionStore: ) def get_auth_chain_difference( - self, state_sets: List[Set[str]] + self, room_id: str, state_sets: List[Set[str]] ) -> Awaitable[Set[str]]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -796,4 +796,4 @@ class StateResolutionStore: An awaitable that resolves to a set of event IDs. """ - return self.store.get_auth_chain_difference(state_sets) + return self.store.get_auth_chain_difference(room_id, state_sets) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index ffc504ce77..f85124bf81 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -97,7 +97,9 @@ async def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference( + room_id, state_sets, event_map, state_res_store + ) full_conflicted_set = set( itertools.chain( @@ -236,6 +238,7 @@ async def _get_power_level_for_sender( async def _get_auth_chain_difference( + room_id: str, state_sets: Sequence[StateMap[str]], event_map: Dict[str, EventBase], state_res_store: "synapse.state.StateResolutionStore", @@ -332,7 +335,9 @@ async def _get_auth_chain_difference( difference_from_event_map = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] - difference = await state_res_store.get_auth_chain_difference(state_sets_ids) + difference = await state_res_store.get_auth_chain_difference( + room_id, state_sets_ids + ) difference.update(difference_from_event_map) return difference diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 2e07c37340..ebffd89251 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) - async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: + async def get_auth_chain_difference( + self, room_id: str, state_sets: List[Set[str]] + ) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). 
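Only the storage layer needs the room ID (so it can pick an algorithm per room); the set-algebra contract itself is unchanged, as the in-memory test store below makes explicit. A distilled sketch of that contract:

from typing import AbstractSet, List, Set


def auth_chain_difference(chains: List[AbstractSet[str]]) -> Set[str]:
    # events reachable from at least one chain, but not from all of them
    common = set(chains[0]).intersection(*chains[1:])
    return set().union(*chains) - common


# e.g. auth_chain_difference([{"a", "e"}, {"b", "e"}]) == {"a", "b"}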
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index f5c6db900d..09f4f32a02 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -623,7 +623,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) - diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) self.assertEqual(difference, {c.event_id}) @@ -662,7 +664,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) - diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) self.assertEqual(difference, {d.event_id, c.event_id}) @@ -707,7 +711,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) - diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) self.assertEqual(difference, {d.event_id, e.event_id}) @@ -773,7 +779,7 @@ class TestStateResolutionStore: return list(result) - def get_auth_chain_difference(self, auth_sets): + def get_auth_chain_difference(self, room_id, auth_sets): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 71c21d8c75..482506d731 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -202,39 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # Now actually test that various combinations give the right result: difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) ) self.assertSetEqual(difference, {"a", "b"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}]) ) self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}]) ) self.assertSetEqual(difference, {"a", "b", "c"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}]) + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}]) ) self.assertSetEqual(difference, {"a", "b"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}]) ) self.assertSetEqual(difference, {"a", "b", "d", "e"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}]) ) self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}]) ) 
self.assertSetEqual(difference, {"a", "b"}) - difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}]) + ) self.assertSetEqual(difference, set()) -- cgit 1.5.1 From 96358cb42410a4be6268eaa3ffec229c550208ea Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 4 Dec 2020 10:56:28 -0500 Subject: Add authentication to replication endpoints. (#8853) Authentication is done by checking a shared secret provided in the Synapse configuration file. --- changelog.d/8853.feature | 1 + docs/sample_config.yaml | 7 ++ docs/workers.md | 6 +- synapse/config/workers.py | 10 +++ synapse/replication/http/_base.py | 47 ++++++++-- tests/replication/test_auth.py | 119 ++++++++++++++++++++++++++ tests/replication/test_client_reader_shard.py | 9 +- 7 files changed, 184 insertions(+), 15 deletions(-) create mode 100644 changelog.d/8853.feature create mode 100644 tests/replication/test_auth.py (limited to 'tests') diff --git a/changelog.d/8853.feature b/changelog.d/8853.feature new file mode 100644 index 0000000000..63c59f4ff2 --- /dev/null +++ b/changelog.d/8853.feature @@ -0,0 +1 @@ +Add optional HTTP authentication to replication endpoints. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 6dbccf5932..8712c580c0 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2589,6 +2589,13 @@ opentracing: # #run_background_tasks_on: worker1 +# A shared secret used by the replication APIs to authenticate HTTP requests +# from workers. +# +# By default this is unused and traffic is not authenticated. +# +#worker_replication_secret: "" + # Configuration for Redis when using workers. This *must* be enabled when # using workers (unless using old style direct TCP configuration). diff --git a/docs/workers.md b/docs/workers.md index c53d1bd2ff..efe97af31a 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -89,7 +89,8 @@ shared configuration file. Normally, only a couple of changes are needed to make an existing configuration file suitable for use with workers. First, you need to enable an "HTTP replication listener" for the main process; and secondly, you need to enable redis-based -replication. For example: +replication. Optionally, a shared secret can be used to authenticate HTTP +traffic between workers. For example: ```yaml @@ -103,6 +104,9 @@ listeners: resources: - names: [replication] +# Add a random shared secret to authenticate traffic. +worker_replication_secret: "" + redis: enabled: true ``` diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 57ab097eba..7ca9efec52 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -85,6 +85,9 @@ class WorkerConfig(Config): # The port on the main synapse for HTTP replication endpoint self.worker_replication_http_port = config.get("worker_replication_http_port") + # The shared secret used for authentication when connecting to the main synapse. + self.worker_replication_secret = config.get("worker_replication_secret", None) + self.worker_name = config.get("worker_name", self.worker_app) self.worker_main_http_uri = config.get("worker_main_http_uri", None) @@ -185,6 +188,13 @@ class WorkerConfig(Config): # data). If not provided this defaults to the main process. # #run_background_tasks_on: worker1 + + # A shared secret used by the replication APIs to authenticate HTTP requests + # from workers. + # + # By default this is unused and traffic is not authenticated. 
+ # + #worker_replication_secret: "" """ def read_arguments(self, args): diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 2b3972cb14..1492ac922c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): assert self.METHOD in ("PUT", "POST", "GET") + self._replication_secret = None + if hs.config.worker.worker_replication_secret: + self._replication_secret = hs.config.worker.worker_replication_secret + + def _check_auth(self, request) -> None: + # Get the authorization header. + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + + if len(auth_headers) > 1: + raise RuntimeError("Too many Authorization headers.") + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + received_secret = parts[1].decode("ascii") + if self._replication_secret == received_secret: + # Success! + return + + raise RuntimeError("Invalid Authorization header.") + @abc.abstractmethod async def _serialize_payload(**kwargs): """Static method that is called when creating a request. @@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) + replication_secret = None + if hs.config.worker.worker_replication_secret: + replication_secret = hs.config.worker.worker_replication_secret.encode( + "ascii" + ) + @trace(opname="outgoing_replication_request") @outgoing_gauge.track_inprogress() async def send_request(instance_name="master", **kwargs): @@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # the master, and so whether we should clean up or not. while True: headers = {} # type: Dict[bytes, List[bytes]] + # Add an authorization header, if configured. + if replication_secret: + headers[b"Authorization"] = [b"Bearer " + replication_secret] inject_active_span_byte_dict(headers, None, check_destination=False) try: result = await request_func(uri, data, headers=headers) @@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): """ url_args = list(self.PATH_ARGS) - handler = self._handle_request method = self.METHOD if self.CACHE: - handler = self._cached_handler # type: ignore url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths( - method, [pattern], handler, self.__class__.__name__, + method, [pattern], self._check_auth_and_handle, self.__class__.__name__, ) - def _cached_handler(self, request, txn_id, **kwargs): + def _check_auth_and_handle(self, request, **kwargs): """Called on new incoming requests when caching is enabled. Checks if there is a cached response for the request and returns that, otherwise calls `_handle_request` and caches its response. @@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # We just use the txn_id here, but we probably also want to use the # other PATH_ARGS as well. - assert self.CACHE + # Check the authorization headers before handling the request. 
+ if self._replication_secret: + self._check_auth(request) + + if self.CACHE: + txn_id = kwargs.pop("txn_id") + + return self.response_cache.wrap( + txn_id, self._handle_request, request, **kwargs + ) - return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) + return self._handle_request(request, **kwargs) diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py new file mode 100644 index 0000000000..fe9e4d5f9a --- /dev/null +++ b/tests/replication/test_auth.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Tuple + +from synapse.http.site import SynapseRequest +from synapse.rest.client.v2_alpha import register + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, make_request +from tests.unittest import override_config + +logger = logging.getLogger(__name__) + + +class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): + """Test the authentication of HTTP calls between workers.""" + + servlets = [register.register_servlets] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + # This isn't a real configuration option but is used to provide the main + # homeserver and worker homeserver different options. + main_replication_secret = config.pop("main_replication_secret", None) + if main_replication_secret: + config["worker_replication_secret"] = main_replication_secret + return self.setup_test_homeserver(config=config) + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.client_reader" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + + return config + + def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]: + """Run the actual test: + + 1. Create a worker homeserver. + 2. Start registration by providing a user/password. + 3. Complete registration by providing dummy auth (this hits the main synapse). + 4. Return the final request. + + """ + worker_hs = self.make_worker_hs("synapse.app.client_reader") + site = self._hs_to_site[worker_hs] + + request_1, channel_1 = make_request( + self.reactor, + site, + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) # type: SynapseRequest, FakeChannel + self.assertEqual(request_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + return make_request( + self.reactor, + site, + "POST", + "register", + {"auth": {"session": session, "type": "m.login.dummy"}}, + ) + + def test_no_auth(self): + """With no authentication the request should finish. + """ + request, channel = self._test_register() + self.assertEqual(request.code, 200) + + # We're given a registered user. 
+ self.assertEqual(channel.json_body["user_id"], "@user:test") + + @override_config({"main_replication_secret": "my-secret"}) + def test_missing_auth(self): + """If the main process expects a secret that is not provided, an error results. + """ + request, channel = self._test_register() + self.assertEqual(request.code, 500) + + @override_config( + { + "main_replication_secret": "my-secret", + "worker_replication_secret": "wrong-secret", + } + ) + def test_unauthorized(self): + """If the main process receives the wrong secret, an error results. + """ + request, channel = self._test_register() + self.assertEqual(request.code, 500) + + @override_config({"worker_replication_secret": "my-secret"}) + def test_authorized(self): + """The request should finish when the worker provides the authentication header. + """ + request, channel = self._test_register() + self.assertEqual(request.code, 200) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index 96801db473..fdaad3d8ad 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -14,27 +14,20 @@ # limitations under the License. import logging -from synapse.api.constants import LoginType from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha import register from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker from tests.server import FakeChannel, make_request logger = logging.getLogger(__name__) class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): - """Base class for tests of the replication streams""" + """Test using one or more client readers for registration.""" servlets = [register.register_servlets] - def prepare(self, reactor, clock, hs): - self.recaptcha_checker = DummyRecaptchaChecker(hs) - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.client_reader" -- cgit 1.5.1 From 1f3748f03398f8f91ec5121312aa79dd58306ec1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 7 Dec 2020 10:00:08 -0500 Subject: Do not raise a 500 exception when previewing empty media. (#8883) --- changelog.d/8883.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 6 +++++- tests/test_preview.py | 27 ++++++++++++++++----------- 3 files changed, 22 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8883.bugfix (limited to 'tests') diff --git a/changelog.d/8883.bugfix b/changelog.d/8883.bugfix new file mode 100644 index 0000000000..6137fc5b2b --- /dev/null +++ b/changelog.d/8883.bugfix @@ -0,0 +1 @@ +Fix a 500 error when attempting to preview an empty HTML file. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index dce6c4d168..1082389d9b 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -676,7 +676,11 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def decode_and_calc_og(body, media_uri, request_encoding=None): +def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]: + # If there's no body, nothing useful is going to be found. 
+ if not body: + return {} + from lxml import etree try: diff --git a/tests/test_preview.py b/tests/test_preview.py index 7f67ee9e1f..a883d707df 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -56,7 +56,7 @@ class PreviewTestCase(unittest.TestCase): desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -69,7 +69,7 @@ class PreviewTestCase(unittest.TestCase): desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø lies in Northern Norway. The municipality has a population of" " (2015) 72,066, but with an annual influx of students it has over 75,000" @@ -96,7 +96,7 @@ class PreviewTestCase(unittest.TestCase): desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -122,7 +122,7 @@ class PreviewTestCase(unittest.TestCase): ] desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -149,7 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment(self): html = """ @@ -164,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment2(self): html = """ @@ -182,7 +182,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals( + self.assertEqual( og, { "og:title": "Foo", @@ -203,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_missing_title(self): html = """ @@ -216,7 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) + self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) def test_h1_as_title(self): html = """ @@ -230,7 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) def test_missing_title_and_broken_h1(self): html = """ @@ -244,4 +244,9 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) + self.assertEqual(og, {"og:title": 
None, "og:description": "Some text."}) + + def test_empty(self): + html = "" + og = decode_and_calc_og(html, "http://example.com/test.html") + self.assertEqual(og, {}) -- cgit 1.5.1 From ff1f0ee09472b554832fb39952f389d01a4233ac Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Mon, 7 Dec 2020 19:13:07 +0000 Subject: Call set_avatar_url with target_user, not user_id (#8872) * Call set_avatar_url with target_user, not user_id Fixes https://github.com/matrix-org/synapse/issues/8871 * Create 8872.bugfix * Update synapse/rest/admin/users.py Co-authored-by: Patrick Cloke * Testing * Update changelog.d/8872.bugfix Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Co-authored-by: Patrick Cloke Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/8872.bugfix | 1 + synapse/rest/admin/users.py | 4 ++-- tests/rest/admin/test_user.py | 7 ++++++- 3 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8872.bugfix (limited to 'tests') diff --git a/changelog.d/8872.bugfix b/changelog.d/8872.bugfix new file mode 100644 index 0000000000..ed00b70a0f --- /dev/null +++ b/changelog.d/8872.bugfix @@ -0,0 +1 @@ +Fix a bug where `PUT /_synapse/admin/v2/users/` failed to create a new user when `avatar_url` is specified. Bug introduced in Synapse v1.9.0. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 90940ff185..88cba369f5 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -320,9 +320,9 @@ class UserRestServletV2(RestServlet): data={}, ) - if "avatar_url" in body and type(body["avatar_url"]) == str: + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( - user_id, requester, body["avatar_url"], True + target_user, requester, body["avatar_url"], True ) ret = await self.admin_handler.get_user(target_user) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 35c546aa69..ba1438cdc7 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -561,7 +561,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": True, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - "avatar_url": None, + "avatar_url": "mxc://fibble/wibble", } ) @@ -578,6 +578,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual(True, channel.json_body["admin"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user request, channel = self.make_request( @@ -592,6 +593,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(True, channel.json_body["admin"]) self.assertEqual(False, channel.json_body["is_guest"]) self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) def test_create_user(self): """ @@ -606,6 +608,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": False, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": "mxc://fibble/wibble", } ) @@ -622,6 +625,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual(False, 
channel.json_body["admin"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user request, channel = self.make_request( @@ -636,6 +640,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(False, channel.json_body["admin"]) self.assertEqual(False, channel.json_body["is_guest"]) self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} -- cgit 1.5.1 From cd9e72b185e36ac3f18ab1fe567fdeee87bfcb8e Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Tue, 8 Dec 2020 16:51:03 -0600 Subject: Add X-Robots-Tag header to stop crawlers from indexing media (#8887) Fixes / related to: https://github.com/matrix-org/synapse/issues/6533 This should do essentially the same thing as a robots.txt file telling robots to not index the media repo. https://developers.google.com/search/reference/robots_meta_tag Signed-off-by: Aaron Raimist --- changelog.d/8887.feature | 1 + synapse/rest/media/v1/_base.py | 5 +++++ tests/rest/media/v1/test_media_storage.py | 13 +++++++++++++ 3 files changed, 19 insertions(+) create mode 100644 changelog.d/8887.feature (limited to 'tests') diff --git a/changelog.d/8887.feature b/changelog.d/8887.feature new file mode 100644 index 0000000000..729eb1f1ea --- /dev/null +++ b/changelog.d/8887.feature @@ -0,0 +1 @@ +Add `X-Robots-Tag` header to stop web crawlers from indexing media. diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 67aa993f19..47c2b44bff 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -155,6 +155,11 @@ def add_file_headers(request, media_type, file_size, upload_name): request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") request.setHeader(b"Content-Length", b"%d" % (file_size,)) + # Tell web crawlers to not index, archive, or follow links in media. This + # should help to prevent things in the media repo from showing up in web + # search results. + request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") + # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 4c749f1a61..6f0677d335 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -362,3 +362,16 @@ class MediaRepoTests(unittest.HomeserverTestCase): "error": "Not found [b'example.com', b'12345']", }, ) + + def test_x_robots_tag_header(self): + """ + Tests that the `X-Robots-Tag` header is present, which informs web crawlers + to not index, archive, or follow links in media. + """ + channel = self._req(b"inline; filename=out" + self.test_image.extension) + + headers = channel.headers + self.assertEqual( + headers.getRawHeaders(b"X-Robots-Tag"), + [b"noindex, nofollow, noarchive, noimageindex"], + ) -- cgit 1.5.1 From 6ff34e00d9bd4e8b0b91347a359b17abeaa22e5a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Dec 2020 12:23:30 -0500 Subject: Skip the SAML tests if xmlsec1 isn't available. 
(#8905) --- changelog.d/8905.misc | 1 + tests/handlers/test_saml.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+) create mode 100644 changelog.d/8905.misc (limited to 'tests') diff --git a/changelog.d/8905.misc b/changelog.d/8905.misc new file mode 100644 index 0000000000..a9a11a2303 --- /dev/null +++ b/changelog.d/8905.misc @@ -0,0 +1 @@ +Skip the SAML tests if the requirements (`pysaml2` and `xmlsec1`) aren't available. diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 45dc17aba5..d21e5588ca 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -19,6 +19,24 @@ from synapse.handlers.sso import MappingException from tests.unittest import HomeserverTestCase, override_config +# Check if we have the dependencies to run the tests. +try: + import saml2.config + from saml2.sigver import SigverError + + has_saml2 = True + + # pysaml2 can be installed and imported, but might not be able to find xmlsec1. + config = saml2.config.SPConfig() + try: + config.load({"metadata": {}}) + has_xmlsec1 = True + except SigverError: + has_xmlsec1 = False +except ImportError: + has_saml2 = False + has_xmlsec1 = False + # These are a few constants that are used as config parameters in the tests. BASE_URL = "https://synapse/" @@ -86,6 +104,11 @@ class SamlHandlerTestCase(HomeserverTestCase): return hs + if not has_saml2: + skip = "Requires pysaml2" + elif not has_xmlsec1: + skip = "Requires xmlsec1" + def test_map_saml_response_to_user(self): """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) -- cgit 1.5.1 From 344ab0b53abc0291d79882f8bdc1a853f7495ed4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Dec 2020 13:56:06 -0500 Subject: Default to blacklisting reserved IP ranges and add a whitelist. (#8870) This defaults `ip_range_blacklist` to reserved IP ranges and also adds an `ip_range_whitelist` setting to override it. --- INSTALL.md | 7 ++- UPGRADE.rst | 21 ++++++++ changelog.d/8821.bugfix | 2 +- changelog.d/8870.bugfix | 1 + docs/sample_config.yaml | 66 ++++++++++++++++-------- synapse/config/federation.py | 59 ++++------------------ synapse/config/repository.py | 20 +++----- synapse/config/server.py | 80 ++++++++++++++++++++++++++++++ synapse/server.py | 3 +- tests/replication/test_multi_media_repo.py | 2 +- 10 files changed, 172 insertions(+), 89 deletions(-) create mode 100644 changelog.d/8870.bugfix (limited to 'tests') diff --git a/INSTALL.md b/INSTALL.md index eaeb690092..eb5f506de9 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -557,10 +557,9 @@ This is critical from a security perspective to stop arbitrary Matrix users spidering 'internal' URLs on your network. At the very least we recommend that your loopback and RFC1918 IP addresses are blacklisted. -This also requires the optional `lxml` and `netaddr` python dependencies to be -installed. This in turn requires the `libxml2` library to be available - on -Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for -your OS. +This also requires the optional `lxml` python dependency to be installed. This +in turn requires the `libxml2` library to be available - on Debian/Ubuntu this +means `apt-get install libxml2-dev`, or equivalent for your OS. 
# Troubleshooting Installation diff --git a/UPGRADE.rst b/UPGRADE.rst index 6825b567e9..54a40bd42f 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,27 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.25.0 +==================== + +Blacklisting IP ranges +---------------------- + +Synapse v1.25.0 includes new settings, ``ip_range_blacklist`` and +``ip_range_whitelist``, for controlling outgoing requests from Synapse for federation, +identity servers, push, and for checking key validity for third-party invite events. +The previous setting, ``federation_ip_range_blacklist``, is deprecated. The new +``ip_range_blacklist`` defaults to private IP ranges if it is not defined. + +If you have never customised ``federation_ip_range_blacklist`` it is recommended +that you remove that setting. + +If you have customised ``federation_ip_range_blacklist`` you should update the +setting name to ``ip_range_blacklist``. + +If you have a custom push server that is reached via private IP space you may +need to customise ``ip_range_blacklist`` or ``ip_range_whitelist``. + Upgrading to v1.24.0 ==================== diff --git a/changelog.d/8821.bugfix b/changelog.d/8821.bugfix index 8ddfbf31ce..39f53174ad 100644 --- a/changelog.d/8821.bugfix +++ b/changelog.d/8821.bugfix @@ -1 +1 @@ -Apply the `federation_ip_range_blacklist` to push and key revocation requests. +Apply an IP range blacklist to push and key revocation requests. diff --git a/changelog.d/8870.bugfix b/changelog.d/8870.bugfix new file mode 100644 index 0000000000..39f53174ad --- /dev/null +++ b/changelog.d/8870.bugfix @@ -0,0 +1 @@ +Apply an IP range blacklist to push and key revocation requests. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 68c8f4f0e2..f196781c1c 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -144,6 +144,35 @@ pid_file: DATADIR/homeserver.pid # #enable_search: false +# Prevent outgoing requests from being sent to the following blacklisted IP address +# CIDR ranges. If this option is not specified then it defaults to private IP +# address ranges (see the example below). +# +# The blacklist applies to the outbound requests for federation, identity servers, +# push servers, and for checking key validity for third-party invite events. +# +# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly +# listed here, since they correspond to unroutable addresses.) +# +# This option replaces federation_ip_range_blacklist in Synapse v1.25.0. +# +#ip_range_blacklist: +# - '127.0.0.0/8' +# - '10.0.0.0/8' +# - '172.16.0.0/12' +# - '192.168.0.0/16' +# - '100.64.0.0/10' +# - '192.0.0.0/24' +# - '169.254.0.0/16' +# - '198.18.0.0/15' +# - '192.0.2.0/24' +# - '198.51.100.0/24' +# - '203.0.113.0/24' +# - '224.0.0.0/4' +# - '::1/128' +# - 'fe80::/10' +# - 'fc00::/7' + # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -642,28 +671,17 @@ acme: # - nyc.example.com # - syd.example.com -# Prevent outgoing requests from being sent to the following blacklisted IP address -# CIDR ranges. If this option is not specified, or specified with an empty list, -# no IP range blacklist will be enforced. +# List of IP address CIDR ranges that should be allowed for federation, +# identity servers, push servers, and for checking key validity for +# third-party invite events. 
This is useful for specifying exceptions to +# wide-ranging blacklisted target IP ranges - e.g. for communication with +# a push server only visible in your network. # -# The blacklist applies to the outbound requests for federation, identity servers, -# push servers, and for checking key validitity for third-party invite events. -# -# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly -# listed here, since they correspond to unroutable addresses.) -# -# This option replaces federation_ip_range_blacklist in Synapse v1.24.0. +# This whitelist overrides ip_range_blacklist and defaults to an empty +# list. # -ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' +#ip_range_whitelist: +# - '192.168.1.1' # Report prometheus metrics on the age of PDUs being sent to and received from # the following domains. This can be used to give an idea of "delay" on inbound @@ -955,9 +973,15 @@ media_store_path: "DATADIR/media_store" # - '172.16.0.0/12' # - '192.168.0.0/16' # - '100.64.0.0/10' +# - '192.0.0.0/24' # - '169.254.0.0/16' +# - '198.18.0.0/15' +# - '192.0.2.0/24' +# - '198.51.100.0/24' +# - '203.0.113.0/24' +# - '224.0.0.0/4' # - '::1/128' -# - 'fe80::/64' +# - 'fe80::/10' # - 'fc00::/7' # List of IP address CIDR ranges that the URL preview spider is allowed diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 27ccf61c3c..a03a419e23 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -12,12 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Optional -from netaddr import IPSet - -from synapse.config._base import Config, ConfigError +from synapse.config._base import Config from synapse.config._util import validate_config @@ -36,31 +33,6 @@ class FederationConfig(Config): for domain in federation_domain_whitelist: self.federation_domain_whitelist[domain] = True - ip_range_blacklist = config.get("ip_range_blacklist", []) - - # Attempt to create an IPSet from the given ranges - try: - self.ip_range_blacklist = IPSet(ip_range_blacklist) - except Exception as e: - raise ConfigError("Invalid range(s) provided in ip_range_blacklist: %s" % e) - # Always blacklist 0.0.0.0, :: - self.ip_range_blacklist.update(["0.0.0.0", "::"]) - - # The federation_ip_range_blacklist is used for backwards-compatibility - # and only applies to federation and identity servers. If it is not given, - # default to ip_range_blacklist. - federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", ip_range_blacklist - ) - try: - self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) - except Exception as e: - raise ConfigError( - "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e - ) - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - federation_metrics_domains = config.get("federation_metrics_domains") or [] validate_config( _METRICS_FOR_DOMAINS_SCHEMA, @@ -84,28 +56,17 @@ class FederationConfig(Config): # - nyc.example.com # - syd.example.com - # Prevent outgoing requests from being sent to the following blacklisted IP address - # CIDR ranges. If this option is not specified, or specified with an empty list, - # no IP range blacklist will be enforced. 
- # - # The blacklist applies to the outbound requests for federation, identity servers, - # push servers, and for checking key validitity for third-party invite events. - # - # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly - # listed here, since they correspond to unroutable addresses.) + # List of IP address CIDR ranges that should be allowed for federation, + # identity servers, push servers, and for checking key validity for + # third-party invite events. This is useful for specifying exceptions to + # wide-ranging blacklisted target IP ranges - e.g. for communication with + # a push server only visible in your network. # - # This option replaces federation_ip_range_blacklist in Synapse v1.24.0. + # This whitelist overrides ip_range_blacklist and defaults to an empty + # list. # - ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' + #ip_range_whitelist: + # - '192.168.1.1' # Report prometheus metrics on the age of PDUs being sent to and received from # the following domains. This can be used to give an idea of "delay" on inbound diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 17ce9145ef..850ac3ebd6 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -17,6 +17,9 @@ import os from collections import namedtuple from typing import Dict, List +from netaddr import IPSet + +from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module @@ -184,9 +187,6 @@ class ContentRepositoryConfig(Config): "to work" ) - # netaddr is a dependency for url_preview - from netaddr import IPSet - self.url_preview_ip_range_blacklist = IPSet( config["url_preview_ip_range_blacklist"] ) @@ -215,6 +215,10 @@ class ContentRepositoryConfig(Config): # strip final NL formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1] + ip_range_blacklist = "\n".join( + " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST + ) + return ( r""" ## Media Store ## @@ -285,15 +289,7 @@ class ContentRepositoryConfig(Config): # you uncomment the following list as a starting point. # #url_preview_ip_range_blacklist: - # - '127.0.0.0/8' - # - '10.0.0.0/8' - # - '172.16.0.0/12' - # - '192.168.0.0/16' - # - '100.64.0.0/10' - # - '169.254.0.0/16' - # - '::1/128' - # - 'fe80::/64' - # - 'fc00::/7' +%(ip_range_blacklist)s # List of IP address CIDR ranges that the URL preview spider is allowed # to access even if they are specified in url_preview_ip_range_blacklist. diff --git a/synapse/config/server.py b/synapse/config/server.py index 85aa49c02d..f3815e5add 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set import attr import yaml +from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name @@ -39,6 +40,34 @@ logger = logging.Logger(__name__) # in the list. DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] +DEFAULT_IP_RANGE_BLACKLIST = [ + # Localhost + "127.0.0.0/8", + # Private networks. + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + # Carrier grade NAT. + "100.64.0.0/10", + # Address registry. + "192.0.0.0/24", + # Link-local networks. + "169.254.0.0/16", + # Testing networks. 
+ "198.18.0.0/15", + "192.0.2.0/24", + "198.51.100.0/24", + "203.0.113.0/24", + # Multicast. + "224.0.0.0/4", + # Localhost + "::1/128", + # Link-local addresses. + "fe80::/10", + # Unique local addresses. + "fc00::/7", +] + DEFAULT_ROOM_VERSION = "6" ROOM_COMPLEXITY_TOO_GREAT = ( @@ -256,6 +285,38 @@ class ServerConfig(Config): # due to resource constraints self.admin_contact = config.get("admin_contact", None) + ip_range_blacklist = config.get( + "ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST + ) + + # Attempt to create an IPSet from the given ranges + try: + self.ip_range_blacklist = IPSet(ip_range_blacklist) + except Exception as e: + raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e + # Always blacklist 0.0.0.0, :: + self.ip_range_blacklist.update(["0.0.0.0", "::"]) + + try: + self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ())) + except Exception as e: + raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e + + # The federation_ip_range_blacklist is used for backwards-compatibility + # and only applies to federation and identity servers. If it is not given, + # default to ip_range_blacklist. + federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", ip_range_blacklist + ) + try: + self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) + except Exception as e: + raise ConfigError( + "Invalid range(s) provided in federation_ip_range_blacklist." + ) from e + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" @@ -561,6 +622,10 @@ class ServerConfig(Config): def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs ): + ip_range_blacklist = "\n".join( + " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST + ) + _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: unsecure_port = bind_port - 400 @@ -752,6 +817,21 @@ class ServerConfig(Config): # #enable_search: false + # Prevent outgoing requests from being sent to the following blacklisted IP address + # CIDR ranges. If this option is not specified then it defaults to private IP + # address ranges (see the example below). + # + # The blacklist applies to the outbound requests for federation, identity servers, + # push servers, and for checking key validity for third-party invite events. + # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) + # + # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. + # + #ip_range_blacklist: +%(ip_range_blacklist)s + # List of ports that Synapse should listen on, their purpose and their # configuration. # diff --git a/synapse/server.py b/synapse/server.py index 9af759626e..043810ad31 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -370,10 +370,11 @@ class HomeServer(metaclass=abc.ABCMeta): def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient: """ An HTTP client that uses configured HTTP(S) proxies and blacklists IPs - based on the IP range blacklist. + based on the IP range blacklist/whitelist. 
""" return SimpleHttpClient( self, + ip_whitelist=self.config.ip_range_whitelist, ip_blacklist=self.config.ip_range_blacklist, http_proxy=os.getenvb(b"http_proxy"), https_proxy=os.getenvb(b"HTTPS_PROXY"), diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 48b574ccbe..83afd9fd2f 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -48,7 +48,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.user_id = self.register_user("user", "pass") self.access_token = self.login("user", "pass") - self.reactor.lookups["example.com"] = "127.0.0.2" + self.reactor.lookups["example.com"] = "1.2.3.4" def default_config(self): conf = super().default_config() -- cgit 1.5.1 From 1d55c7b56730f0aeb8a620a22d0994f1dc735dfe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Dec 2020 10:17:49 +0000 Subject: Don't ratelimit autojoining of rooms (#8921) Fixes #8866 --- changelog.d/8921.bugfix | 1 + synapse/handlers/room.py | 5 ++++- synapse/handlers/room_member.py | 23 +++++++++++++---------- tests/rest/client/v1/test_rooms.py | 16 ++++++++++++++++ 4 files changed, 34 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8921.bugfix (limited to 'tests') diff --git a/changelog.d/8921.bugfix b/changelog.d/8921.bugfix new file mode 100644 index 0000000000..7f6f0b8a76 --- /dev/null +++ b/changelog.d/8921.bugfix @@ -0,0 +1 @@ +Fix bug where we ratelimited auto joining of rooms on registration (using `auto_join_rooms` config). diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 930047e730..82fb72b381 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -440,6 +440,7 @@ class RoomCreationHandler(BaseHandler): invite_list=[], initial_state=initial_state, creation_content=creation_content, + ratelimit=False, ) # Transfer membership events @@ -735,6 +736,7 @@ class RoomCreationHandler(BaseHandler): room_alias=room_alias, power_level_content_override=power_level_content_override, creator_join_profile=creator_join_profile, + ratelimit=ratelimit, ) if "name" in config: @@ -838,6 +840,7 @@ class RoomCreationHandler(BaseHandler): room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, + ratelimit: bool = True, ) -> int: """Sends the initial events into a new room. @@ -884,7 +887,7 @@ class RoomCreationHandler(BaseHandler): creator.user, room_id, "join", - ratelimit=False, + ratelimit=ratelimit, content=creator_join_profile, ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c002886324..d85110a35e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -203,7 +203,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Only rate-limit if the user actually joined the room, otherwise we'll end # up blocking profile updates. 
- if newly_joined: + if newly_joined and ratelimit: time_now_s = self.clock.time() ( allowed, @@ -488,17 +488,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - time_now_s = self.clock.time() - ( - allowed, - time_allowed, - ) = self._join_rate_limiter_remote.can_requester_do_action(requester,) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)) + if ratelimit: + time_now_s = self.clock.time() + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_remote.can_requester_do_action( + requester, ) + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e67de41c18..55d872f0ee 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -26,6 +26,7 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus +from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account from synapse.types import JsonDict, RoomAlias, UserID @@ -625,6 +626,7 @@ class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" servlets = [ + admin.register_servlets, profile.register_servlets, room.register_servlets, ] @@ -703,6 +705,20 @@ class RoomJoinRatelimitTestCase(RoomBase): request, channel = self.make_request("POST", path % room_id, {}) self.assertEquals(channel.code, 200) + @unittest.override_config( + { + "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}, + "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"], + "autocreate_auto_join_rooms": True, + }, + ) + def test_autojoin_rooms(self): + user_id = self.register_user("testuser", "password") + + # Check that the new user successfully joined the four rooms + rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 4) + class RoomMessagesTestCase(RoomBase): """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """ -- cgit 1.5.1 From 0a34cdfc6682c2654c745c4d7c2f5ffd1865dbc8 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 11 Dec 2020 11:42:47 +0100 Subject: Add number of local devices to Room Details Admin API (#8886) --- changelog.d/8886.feature | 1 + docs/admin_api/rooms.md | 24 +++++++++------- synapse/rest/admin/rooms.py | 48 ++++++++++++++++++++----------- synapse/storage/databases/main/devices.py | 32 +++++++++++++++++++++ tests/rest/admin/test_room.py | 34 ++++++++++++++++++++++ tests/storage/test_devices.py | 26 +++++++++++++++++ 6 files changed, 138 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8886.feature (limited to 'tests') diff --git a/changelog.d/8886.feature b/changelog.d/8886.feature new file mode 100644 index 0000000000..9e446f28bd --- /dev/null +++ b/changelog.d/8886.feature @@ -0,0 +1 @@ +Add number of local devices to Room Details Admin API. Contributed by @dklimpel. 
\ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 3ac21b5cae..d7b1740fe3 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -87,7 +87,7 @@ GET /_synapse/admin/v1/rooms Response: -``` +```jsonc { "rooms": [ { @@ -139,7 +139,7 @@ GET /_synapse/admin/v1/rooms?search_term=TWIM Response: -``` +```json { "rooms": [ { @@ -174,7 +174,7 @@ GET /_synapse/admin/v1/rooms?order_by=size Response: -``` +```jsonc { "rooms": [ { @@ -230,14 +230,14 @@ GET /_synapse/admin/v1/rooms?order_by=size&from=100 Response: -``` +```jsonc { "rooms": [ { "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", "name": "Music Theory", "canonical_alias": "#musictheory:matrix.org", - "joined_members": 127 + "joined_members": 127, "joined_local_members": 2, "version": "1", "creator": "@foo:matrix.org", @@ -254,7 +254,7 @@ Response: "room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk", "name": "weechat-matrix", "canonical_alias": "#weechat-matrix:termina.org.uk", - "joined_members": 137 + "joined_members": 137, "joined_local_members": 20, "version": "4", "creator": "@foo:termina.org.uk", @@ -289,6 +289,7 @@ The following fields are possible in the JSON response body: * `canonical_alias` - The canonical (main) alias address of the room. * `joined_members` - How many users are currently in the room. * `joined_local_members` - How many local users are currently in the room. +* `joined_local_devices` - How many local devices are currently in the room. * `version` - The version of the room as a string. * `creator` - The `user_id` of the room creator. * `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active. @@ -311,15 +312,16 @@ GET /_synapse/admin/v1/rooms/ Response: -``` +```json { "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", "name": "Music Theory", "avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV", "topic": "Theory, Composition, Notation, Analysis", "canonical_alias": "#musictheory:matrix.org", - "joined_members": 127 + "joined_members": 127, "joined_local_members": 2, + "joined_local_devices": 2, "version": "1", "creator": "@foo:matrix.org", "encryption": null, @@ -353,13 +355,13 @@ GET /_synapse/admin/v1/rooms//members Response: -``` +```json { "members": [ "@foo:matrix.org", "@bar:matrix.org", - "@foobar:matrix.org - ], + "@foobar:matrix.org" + ], "total": 3 } ``` diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 25f89e4685..b902af8028 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -14,7 +14,7 @@ # limitations under the License. 
import logging from http import HTTPStatus -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -25,13 +25,17 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -45,12 +49,14 @@ class ShutdownRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/shutdown_room/(?P[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_shutdown_handler = hs.get_room_shutdown_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -86,13 +92,15 @@ class DeleteRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -146,12 +154,12 @@ class ListRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -236,19 +244,24 @@ class RoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P[^/]+)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room_with_stats(room_id) if not ret: raise NotFoundError("Room not found") - return 200, ret + members = await self.store.get_users_in_room(room_id) + ret["joined_local_devices"] = await self.store.count_devices_by_users(members) + + return (200, ret) class RoomMembersRestServlet(RestServlet): @@ -258,12 +271,14 @@ class RoomMembersRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P[^/]+)/members") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> 
Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) @@ -280,14 +295,16 @@ class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_member_handler = hs.get_room_member_handler() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() - async def on_POST(self, request, room_identifier): + async def on_POST( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -314,7 +331,6 @@ class JoinRoomAliasServlet(RestServlet): handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) - room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index dfb4f87b8f..9097677648 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -57,6 +57,38 @@ class DeviceWorkerStore(SQLBaseStore): self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) + async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: + """Retrieve number of all devices of given users. + Only returns number of devices that are not marked as hidden. + + Args: + user_ids: The IDs of the users which owns devices + Returns: + Number of devices of this users. + """ + + def count_devices_by_users_txn(txn, user_ids): + sql = """ + SELECT count(*) + FROM devices + WHERE + hidden = '0' AND + """ + + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", user_ids + ) + + txn.execute(sql + clause, args) + return txn.fetchone()[0] + + if not user_ids: + return 0 + + return await self.db_pool.runInteraction( + "count_devices_by_users", count_devices_by_users_txn, user_ids + ) + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. 
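[Editor's note] To make the contract of the new `count_devices_by_users` store method concrete, here is a rough in-memory model: it counts non-hidden devices belonging to any of the supplied users, and short-circuits to zero when no user IDs are given. The sample rows are hypothetical; the real method issues a single SELECT over the `devices` table with an IN-list clause, as the hunk above shows.

```python
from typing import Iterable, List, Optional, Tuple

# Hypothetical rows from the `devices` table: (user_id, device_id, hidden).
DEVICES: List[Tuple[str, str, bool]] = [
    ("@alice:test", "device1", False),
    ("@alice:test", "device2", False),
    ("@bob:test", "device3", False),
    ("@bob:test", "signing", True),  # hidden devices are never counted
]


def count_devices_by_users(user_ids: Optional[Iterable[str]] = None) -> int:
    # Mirror the store's short-circuit: no users means a count of zero.
    if not user_ids:
        return 0
    wanted = set(user_ids)
    return sum(1 for user, _, hidden in DEVICES if user in wanted and not hidden)


assert count_devices_by_users(["@alice:test", "@bob:test"]) == 3
assert count_devices_by_users() == 0
```

The admin API's `joined_local_devices` field is then just this count applied to the room's current members, which is why the test below expects it to drop back to zero once everyone has left the room.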
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 46933a0493..9c100050d2 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1084,6 +1084,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("canonical_alias", channel.json_body) self.assertIn("joined_members", channel.json_body) self.assertIn("joined_local_members", channel.json_body) + self.assertIn("joined_local_devices", channel.json_body) self.assertIn("version", channel.json_body) self.assertIn("creator", channel.json_body) self.assertIn("encryption", channel.json_body) @@ -1096,6 +1097,39 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id_1, channel.json_body["room_id"]) + def test_single_room_devices(self): + """Test that `joined_local_devices` can be requested correctly""" + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["joined_local_devices"]) + + # Have another user join the room + user_1 = self.register_user("foo", "pass") + user_tok_1 = self.login("foo", "pass") + self.helper.join(room_id_1, user_1, tok=user_tok_1) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(2, channel.json_body["joined_local_devices"]) + + # leave room + self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok) + self.helper.leave(room_id_1, user_1, tok=user_tok_1) + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["joined_local_devices"]) + def test_room_members(self): """Test that room members can be requested correctly""" # Create two test rooms diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index ecb00f4e02..dabc1c5f09 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -79,6 +79,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase): res["device2"], ) + @defer.inlineCallbacks + def test_count_devices_by_users(self): + yield defer.ensureDeferred( + self.store.store_device("user_id", "device1", "display_name 1") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id", "device2", "display_name 2") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id2", "device3", "display_name 3") + ) + + res = yield defer.ensureDeferred(self.store.count_devices_by_users()) + self.assertEqual(0, res) + + res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"])) + self.assertEqual(0, res) + + res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"])) + self.assertEqual(2, res) + + res = yield defer.ensureDeferred( + self.store.count_devices_by_users(["user_id", "user_id2"]) + ) + self.assertEqual(3, res) + @defer.inlineCallbacks def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] -- cgit 1.5.1 From 3af0672350965b4fddf3aad2904795bbd0199c30 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 11 Dec 
2020 07:25:01 -0500 Subject: Improve tests for structured logging. (#8916) --- changelog.d/8916.misc | 1 + tests/logging/test_terse_json.py | 73 +++++++++++++++++++++++++--------------- 2 files changed, 47 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8916.misc (limited to 'tests') diff --git a/changelog.d/8916.misc b/changelog.d/8916.misc new file mode 100644 index 0000000000..c71ef480e6 --- /dev/null +++ b/changelog.d/8916.misc @@ -0,0 +1 @@ +Improve structured logging tests. diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 73f469b802..f6e7e5fdaa 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -18,30 +18,35 @@ import logging from io import StringIO from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter +from synapse.logging.context import LoggingContext, LoggingContextFilter from tests.logging import LoggerCleanupMixin from tests.unittest import TestCase class TerseJsonTestCase(LoggerCleanupMixin, TestCase): + def setUp(self): + self.output = StringIO() + + def get_log_line(self): + # One log message, with a single trailing newline. + data = self.output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + return json.loads(logs[0]) + def test_terse_json_output(self): """ The Terse JSON formatter converts log messages to JSON. """ - output = StringIO() - - handler = logging.StreamHandler(output) + handler = logging.StreamHandler(self.output) handler.setFormatter(TerseJsonFormatter()) logger = self.get_logger(handler) logger.info("Hello there, %s!", "wally") - # One log message, with a single trailing newline. - data = output.getvalue() - logs = data.splitlines() - self.assertEqual(len(logs), 1) - self.assertEqual(data.count("\n"), 1) - log = json.loads(logs[0]) + log = self.get_log_line() # The terse logger should give us these keys. expected_log_keys = [ @@ -57,9 +62,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): """ Additional information can be included in the structured logging. """ - output = StringIO() - - handler = logging.StreamHandler(output) + handler = logging.StreamHandler(self.output) handler.setFormatter(TerseJsonFormatter()) logger = self.get_logger(handler) @@ -67,12 +70,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True} ) - # One log message, with a single trailing newline. - data = output.getvalue() - logs = data.splitlines() - self.assertEqual(len(logs), 1) - self.assertEqual(data.count("\n"), 1) - log = json.loads(logs[0]) + log = self.get_log_line() # The terse logger should give us these keys. expected_log_keys = [ @@ -96,26 +94,47 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): """ The Terse JSON formatter converts log messages to JSON. """ - output = StringIO() - - handler = logging.StreamHandler(output) + handler = logging.StreamHandler(self.output) handler.setFormatter(JsonFormatter()) logger = self.get_logger(handler) logger.info("Hello there, %s!", "wally") - # One log message, with a single trailing newline. - data = output.getvalue() - logs = data.splitlines() - self.assertEqual(len(logs), 1) - self.assertEqual(data.count("\n"), 1) - log = json.loads(logs[0]) + log = self.get_log_line() + + # The terse logger should give us these keys. 
+        expected_log_keys = [
+            "log",
+            "level",
+            "namespace",
+        ]
+        self.assertCountEqual(log.keys(), expected_log_keys)
+        self.assertEqual(log["log"], "Hello there, wally!")
+
+    def test_with_context(self):
+        """
+        The logging context should be added to the JSON response.
+        """
+        handler = logging.StreamHandler(self.output)
+        handler.setFormatter(JsonFormatter())
+        handler.addFilter(LoggingContextFilter(request=""))
+        logger = self.get_logger(handler)
+
+        with LoggingContext() as context_one:
+            context_one.request = "test"
+            logger.info("Hello there, %s!", "wally")
+
+        log = self.get_log_line()
 
         # The terse logger should give us these keys.
         expected_log_keys = [
             "log",
             "level",
             "namespace",
+            "request",
+            "scope",
         ]
         self.assertCountEqual(log.keys(), expected_log_keys)
         self.assertEqual(log["log"], "Hello there, wally!")
+        self.assertEqual(log["request"], "test")
+        self.assertIsNone(log["scope"])
--
cgit 1.5.1


From f14428b25c37e44675edac4a80d7bd1e47112586 Mon Sep 17 00:00:00 2001
From: David Teller
Date: Fri, 11 Dec 2020 20:05:15 +0100
Subject: Allow spam-checker modules to provide async methods. (#8890)

Spam checker modules can now provide async methods. This is implemented
in a backwards-compatible manner.
---
 changelog.d/8890.feature                      |  1 +
 docs/spam_checker.md                          | 19 ++++++---
 synapse/events/spamcheck.py                   | 55 +++++++++++++++++++--------
 synapse/federation/federation_base.py         |  7 +++-
 synapse/handlers/auth.py                      |  8 ++--
 synapse/handlers/directory.py                 |  6 ++-
 synapse/handlers/federation.py                |  2 +-
 synapse/handlers/message.py                   |  2 +-
 synapse/handlers/receipts.py                  |  7 +---
 synapse/handlers/register.py                  |  2 +-
 synapse/handlers/room.py                      |  4 +-
 synapse/handlers/room_member.py               |  2 +-
 synapse/handlers/user_directory.py            | 10 ++---
 synapse/metrics/background_process_metrics.py |  9 +----
 synapse/rest/media/v1/storage_provider.py     | 16 +++-----
 synapse/server.py                             |  2 +-
 synapse/util/async_helpers.py                 |  8 ++--
 synapse/util/distributor.py                   |  7 +---
 tests/handlers/test_user_directory.py         |  4 +-
 19 files changed, 98 insertions(+), 73 deletions(-)
 create mode 100644 changelog.d/8890.feature

(limited to 'tests')

diff --git a/changelog.d/8890.feature b/changelog.d/8890.feature
new file mode 100644
index 0000000000..97aa72a76e
--- /dev/null
+++ b/changelog.d/8890.feature
@@ -0,0 +1 @@
+Spam-checkers may now define their methods as `async`.
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index 7fc08f1b70..5b4f6428e6 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -22,6 +22,8 @@ well as some specific methods:
 * `user_may_create_room`
 * `user_may_create_room_alias`
 * `user_may_publish_room`
+* `check_username_for_spam`
+* `check_registration_for_spam`
 
 The details of each of these methods (as well as their inputs and outputs)
 are documented in the `synapse.events.spamcheck.SpamChecker` class.
 
@@ -32,28 +34,33 @@ call back into the homeserver internals.
### Example ```python +from synapse.spam_checker_api import RegistrationBehaviour + class ExampleSpamChecker: def __init__(self, config, api): self.config = config self.api = api - def check_event_for_spam(self, foo): + async def check_event_for_spam(self, foo): return False # allow all events - def user_may_invite(self, inviter_userid, invitee_userid, room_id): + async def user_may_invite(self, inviter_userid, invitee_userid, room_id): return True # allow all invites - def user_may_create_room(self, userid): + async def user_may_create_room(self, userid): return True # allow all room creations - def user_may_create_room_alias(self, userid, room_alias): + async def user_may_create_room_alias(self, userid, room_alias): return True # allow all room aliases - def user_may_publish_room(self, userid, room_id): + async def user_may_publish_room(self, userid, room_id): return True # allow publishing of all rooms - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): return False # allow all usernames + + async def check_registration_for_spam(self, email_threepid, username, request_info): + return RegistrationBehaviour.ALLOW # allow all registrations ``` ## Configuration diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 936896656a..e7e3a7b9a4 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,10 +15,11 @@ # limitations under the License. import inspect -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import Collection +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import synapse.events @@ -39,7 +40,9 @@ class SpamChecker: else: self.spam_checkers.append(module(config=config)) - def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool: + async def check_event_for_spam( + self, event: "synapse.events.EventBase" + ) -> Union[bool, str]: """Checks if a given event is considered "spammy" by this server. If the server considers an event spammy, then it will be rejected if @@ -50,15 +53,16 @@ class SpamChecker: event: the event to be checked Returns: - True if the event is spammy. + True or a string if the event is spammy. If a string is returned it + will be used as the error message returned to the user. """ for spam_checker in self.spam_checkers: - if spam_checker.check_event_for_spam(event): + if await maybe_awaitable(spam_checker.check_event_for_spam(event)): return True return False - def user_may_invite( + async def user_may_invite( self, inviter_userid: str, invitee_userid: str, room_id: str ) -> bool: """Checks if a given user may send an invite @@ -75,14 +79,18 @@ class SpamChecker: """ for spam_checker in self.spam_checkers: if ( - spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + await maybe_awaitable( + spam_checker.user_may_invite( + inviter_userid, invitee_userid, room_id + ) + ) is False ): return False return True - def user_may_create_room(self, userid: str) -> bool: + async def user_may_create_room(self, userid: str) -> bool: """Checks if a given user may create a room If this method returns false, the creation request will be rejected. 
@@ -94,12 +102,15 @@ class SpamChecker: True if the user may create a room, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_create_room(userid) is False: + if ( + await maybe_awaitable(spam_checker.user_may_create_room(userid)) + is False + ): return False return True - def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: + async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: """Checks if a given user may create a room alias If this method returns false, the association request will be rejected. @@ -112,12 +123,17 @@ class SpamChecker: True if the user may create a room alias, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_create_room_alias(userid, room_alias) is False: + if ( + await maybe_awaitable( + spam_checker.user_may_create_room_alias(userid, room_alias) + ) + is False + ): return False return True - def user_may_publish_room(self, userid: str, room_id: str) -> bool: + async def user_may_publish_room(self, userid: str, room_id: str) -> bool: """Checks if a given user may publish a room to the directory If this method returns false, the publish request will be rejected. @@ -130,12 +146,17 @@ class SpamChecker: True if the user may publish the room, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_publish_room(userid, room_id) is False: + if ( + await maybe_awaitable( + spam_checker.user_may_publish_room(userid, room_id) + ) + is False + ): return False return True - def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in @@ -157,12 +178,12 @@ class SpamChecker: if checker: # Make a copy of the user profile object to ensure the spam checker # cannot modify it. 
- if checker(user_profile.copy()): + if await maybe_awaitable(checker(user_profile.copy())): return True return False - def check_registration_for_spam( + async def check_registration_for_spam( self, email_threepid: Optional[dict], username: Optional[str], @@ -185,7 +206,9 @@ class SpamChecker: # spam checker checker = getattr(spam_checker, "check_registration_for_spam", None) if checker: - behaviour = checker(email_threepid, username, request_info) + behaviour = await maybe_awaitable( + checker(email_threepid, username, request_info) + ) assert isinstance(behaviour, RegistrationBehaviour) if behaviour != RegistrationBehaviour.ALLOW: return behaviour diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 38aa47963f..383737520a 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -78,6 +78,7 @@ class FederationBase: ctx = current_context() + @defer.inlineCallbacks def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): @@ -105,7 +106,11 @@ class FederationBase: ) return redacted_event - if self.spam_checker.check_event_for_spam(pdu): + result = yield defer.ensureDeferred( + self.spam_checker.check_event_for_spam(pdu) + ) + + if result: logger.warning( "Event contains spam, redacting %s: %s", pdu.event_id, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 62f98dabc0..8deec4cd0c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import time import unicodedata @@ -59,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils +from synapse.util.async_helpers import maybe_awaitable from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email @@ -1639,6 +1639,6 @@ class PasswordProvider: # This might return an awaitable, if it does block the log out # until it completes. 
- result = g(user_id=user_id, device_id=device_id, access_token=access_token,) - if inspect.isawaitable(result): - await result + await maybe_awaitable( + g(user_id=user_id, device_id=device_id, access_token=access_token,) + ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index ad5683d251..abcf86352d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler): 403, "You must be in the room to create an alias for it" ) - if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): + if not await self.spam_checker.user_may_create_room_alias( + user_id, room_alias + ): raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( @@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler): """ user_id = requester.user.to_string() - if not self.spam_checker.user_may_publish_room(user_id, room_id): + if not await self.spam_checker.user_may_publish_room(user_id, room_id): raise AuthError( 403, "This user is not permitted to publish rooms to the room list" ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index df82e60b33..fd8de8696d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler): if self.hs.config.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") - if not self.spam_checker.user_may_invite( + if not await self.spam_checker.user_may_invite( event.sender, event.state_key, event.room_id ): raise SynapseError( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 96843338ae..2b8aa9443d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -744,7 +744,7 @@ class EventCreationHandler: event.sender, ) - spam_error = self.spam_checker.check_event_for_spam(event) + spam_error = await self.spam_checker.check_event_for_spam(event) if spam_error: if not isinstance(spam_error, str): spam_error = "Spam is not permitted here" diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 153cbae7b9..e850e45e46 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,7 +18,6 @@ from typing import List, Tuple from synapse.appservice import ApplicationService from synapse.handlers._base import BaseHandler from synapse.types import JsonDict, ReadReceipt, get_domain_from_id -from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler): self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. 
- await maybe_awaitable( - self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids - ) + await self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids ) return True diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0d85fd0868..94b5610acd 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) - result = self.spam_checker.check_registration_for_spam( + result = await self.spam_checker.check_registration_for_spam( threepid, localpart, user_agent_ips or [], ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 82fb72b381..7583418946 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - if not self.spam_checker.user_may_create_room(user_id): + if not await self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { @@ -609,7 +609,7 @@ class RoomCreationHandler(BaseHandler): 403, "You are not permitted to create rooms", Codes.FORBIDDEN ) - if not is_requester_admin and not self.spam_checker.user_may_create_room( + if not is_requester_admin and not await self.spam_checker.user_may_create_room( user_id ): raise SynapseError(403, "You are not permitted to create rooms") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d85110a35e..cb5a29bc7e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) block_invite = True - if not self.spam_checker.user_may_invite( + if not await self.spam_checker.user_may_invite( requester.user.to_string(), target.to_string(), room_id ): logger.info("Blocking invite due to spam checker") diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index afbebfc200..f263a638f8 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler): results = await self.store.search_user_dir(user_id, search_term, limit) # Remove any spammy users from the results. - results["results"] = [ - user - for user in results["results"] - if not self.spam_checker.check_username_for_spam(user) - ] + non_spammy_users = [] + for user in results["results"]: + if not await self.spam_checker.check_username_for_spam(user): + non_spammy_users.append(user) + results["results"] = non_spammy_users return results diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 658f6ecd72..76b7decf26 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect import logging import threading from functools import wraps @@ -25,6 +24,7 @@ from twisted.internet import defer from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.opentracing import noop_context_manager, start_active_span +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import resource @@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar if bg_start_span: ctx = start_active_span(desc, tags={"request_id": context.request}) with ctx: - result = func(*args, **kwargs) - - if inspect.isawaitable(result): - result = await result - - return result + return await maybe_awaitable(func(*args, **kwargs)) except Exception: logger.exception( "Background process '%s' threw an exception", desc, diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 18c9ed48d6..67f67efde7 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import os import shutil @@ -21,6 +20,7 @@ from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from synapse.util.async_helpers import maybe_awaitable from ._base import FileInfo, Responder from .media_storage import FileResponder @@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider): if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable(self.backend.store_file(path, file_info)) else: # TODO: Handle errors. async def store(): try: - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable( + self.backend.store_file(path, file_info) + ) except Exception: logger.exception("Error storing file") @@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider): async def fetch(self, path, file_info): # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.fetch(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable(self.backend.fetch(path, file_info)) class FileStorageProviderBackend(StorageProvider): diff --git a/synapse/server.py b/synapse/server.py index 043810ad31..a198b0eb46 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -618,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta): return StatsHandler(self) @cache_in_self - def get_spam_checker(self): + def get_spam_checker(self) -> SpamChecker: return SpamChecker(self) @cache_in_self diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 382f0cf3f0..9a873c8e8e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -15,10 +15,12 @@ # limitations under the License. 
import collections +import inspect import logging from contextlib import contextmanager from typing import ( Any, + Awaitable, Callable, Dict, Hashable, @@ -542,11 +544,11 @@ class DoneAwaitable: raise StopIteration(self.value) -def maybe_awaitable(value): +def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: """Convert a value to an awaitable if not already an awaitable. """ - - if hasattr(value, "__await__"): + if inspect.isawaitable(value): + assert isinstance(value, Awaitable) return value return DoneAwaitable(value) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index f73e95393c..a6ee9edaec 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -105,10 +105,7 @@ class Signal: async def do(observer): try: - result = observer(*args, **kwargs) - if inspect.isawaitable(result): - result = await result - return result + return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: logger.warning( "%s signal observer %s failed: %r", self.name, observer, e, diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 98e5af2072..647a17cb90 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -270,7 +270,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): spam_checker = self.hs.get_spam_checker() class AllowAll: - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -283,7 +283,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that filters all users. class BlockAll: - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): # All users are spammy. return True -- cgit 1.5.1 From 895e04319ba457a855207b8bb804f7360a258464 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 14 Dec 2020 11:38:50 +0000 Subject: Preparatory refactoring of the OidcHandlerTestCase (#8911) * Remove references to handler._auth_handler (and replace them with hs.get_auth_handler) * Factor out a utility function for building Requests * Remove mocks of `OidcHandler._map_userinfo_to_user` This method is going away, so mocking it out is no longer a valid approach. Instead, we mock out lower-level methods (eg _remote_id_from_userinfo), or simply allow the regular implementation to proceed and update the expectations accordingly. * Remove references to `OidcHandler._map_userinfo_to_user` from tests This method is going away, so we can no longer use it as a test point. Instead we build mock "callback" requests which we pass into `handle_oidc_callback`, and verify correct behaviour by mocking out `AuthHandler.complete_sso_login`. 
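Concretely, the tests converge on a pattern roughly like the following sketch
(`simple_async_mock` and `_make_callback_with_userinfo` are the helpers that
appear in this patch; the exact userinfo and assertions vary from test to
test):

```python
from mock import ANY

def test_map_userinfo_sketch(self):
    # Observe success at the AuthHandler boundary rather than reaching into
    # OidcHandler internals such as _map_userinfo_to_user.
    auth_handler = self.hs.get_auth_handler()
    auth_handler.complete_sso_login = simple_async_mock()

    # Drive the full OIDC callback path with a fake browser request.
    self._make_callback_with_userinfo({"sub": "test_user", "username": "test_user"})

    # The handler should have completed an SSO login for the mapped MXID.
    auth_handler.complete_sso_login.assert_called_once_with(
        "@test_user:test", ANY, ANY, {}
    )
```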
--- changelog.d/8911.feature | 1 + tests/handlers/test_oidc.py | 286 ++++++++++++++++++++++---------------------- 2 files changed, 146 insertions(+), 141 deletions(-) create mode 100644 changelog.d/8911.feature (limited to 'tests') diff --git a/changelog.d/8911.feature b/changelog.d/8911.feature new file mode 100644 index 0000000000..d450ef4998 --- /dev/null +++ b/changelog.d/8911.feature @@ -0,0 +1 @@ +Add support for allowing users to pick their own user ID during a single-sign-on login. diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1d99a45436..9878527bab 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -15,7 +15,7 @@ import json from urllib.parse import parse_qs, urlparse -from mock import Mock, patch +from mock import ANY, Mock, patch import pymacaroons @@ -82,7 +82,7 @@ class TestMappingProviderFailures(TestMappingProvider): } -def simple_async_mock(return_value=None, raises=None): +def simple_async_mock(return_value=None, raises=None) -> Mock: # AsyncMock is not available in python3.5, this mimics part of its behaviour async def cb(*args, **kwargs): if raises: @@ -160,6 +160,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args[2], error_description) # Reset the render_error mock self.render_error.reset_mock() + return args def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" @@ -374,26 +375,17 @@ class OidcHandlerTestCase(HomeserverTestCase): "id_token": "id_token", "access_token": "access_token", } + username = "bar" userinfo = { "sub": "foo", - "preferred_username": "bar", + "username": username, } - user_id = "@foo:domain.org" + expected_user_id = "@%s:%s" % (username, self.hs.hostname) self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) - self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock( - spec=[ - "args", - "getCookie", - "addCookie", - "requestHeaders", - "getClientIP", - "get_user_agent", - ] - ) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() code = "code" state = "state" @@ -401,64 +393,54 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url = "http://client/redirect" user_agent = "Browser" ip_address = "10.0.0.1" - request.getCookie.return_value = self.handler._generate_oidc_session_token( + session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - - request.args = {} - request.args[b"code"] = [code.encode("utf-8")] - request.args[b"state"] = [state.encode("utf-8")] - - request.getClientIP.return_value = ip_address - request.get_user_agent.return_value = user_agent + request = self._build_callback_request( + code, state, session, user_agent=user_agent, ip_address=ip_address + ) self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {}, + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - 
self.handler._map_userinfo_to_user.assert_called_once_with( - userinfo, token, user_agent, ip_address - ) self.handler._fetch_userinfo.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors - self.handler._map_userinfo_to_user = simple_async_mock( - raises=MappingException() - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("mapping_error") - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + with patch.object( + self.handler, + "_remote_id_from_userinfo", + new=Mock(side_effect=MappingException()), + ): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mapping_error") # Handle ID token errors self.handler._parse_id_token = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - self.handler._auth_handler.complete_sso_login.reset_mock() + auth_handler.complete_sso_login.reset_mock() self.handler._exchange_code.reset_mock() self.handler._parse_id_token.reset_mock() - self.handler._map_userinfo_to_user.reset_mock() self.handler._fetch_userinfo.reset_mock() # With userinfo fetching self.handler._scopes = [] # do not ask the "openid" scope self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {}, + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with( - userinfo, token, user_agent, ip_address - ) self.handler._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() @@ -609,72 +591,55 @@ class OidcHandlerTestCase(HomeserverTestCase): } userinfo = { "sub": "foo", + "username": "foo", "phone": "1234567", } - user_id = "@foo:domain.org" self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) - self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock( - spec=[ - "args", - "getCookie", - "addCookie", - "requestHeaders", - "getClientIP", - "get_user_agent", - ] - ) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() state = "state" client_redirect_url = "http://client/redirect" - request.getCookie.return_value = self.handler._generate_oidc_session_token( + session = self.handler._generate_oidc_session_token( state=state, nonce="nonce", client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - - request.args = {} - request.args[b"code"] = [b"code"] - request.args[b"state"] = [state.encode("utf-8")] - - request.getClientIP.return_value = "10.0.0.1" - request.get_user_agent.return_value = "Browser" + request = self._build_callback_request("code", state, session) self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {"phone": "1234567"}, + auth_handler.complete_sso_login.assert_called_once_with( + "@foo:test", request, client_redirect_url, {"phone": "1234567"}, ) def test_map_userinfo_to_user(self): """Ensure that mapping the userinfo returned 
from a provider to an MXID works properly.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + userinfo = { "sub": "test_user", "username": "test_user", } - # The token doesn't matter with the default user mapping provider. - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", ANY, ANY, {} ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user_2:test", ANY, ANY, {} ) - self.assertEqual(mxid, "@test_user_2:test") + auth_handler.complete_sso_login.reset_mock() # Test if the mxid is already taken store = self.hs.get_datastore() @@ -683,14 +648,11 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, - ) - self.assertEqual( - str(e.value), "Mapping provider does not support de-duplicating Matrix IDs", + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_not_called() + self.assertRenderedError( + "mapping_error", + "Mapping provider does not support de-duplicating Matrix IDs", ) @override_config({"oidc_config": {"allow_existing_users": True}}) @@ -702,26 +664,26 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user.to_string(), password_hash=None) ) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, {}, ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, {}, ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Note that a second SSO user can be mapped to the same Matrix ID. 
(This # requires a unique sub, but something that maps to the same matrix ID, @@ -732,13 +694,11 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, {}, ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") @@ -755,14 +715,11 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_not_called() + args = self.assertRenderedError("mapping_error") self.assertTrue( - str(e.value).startswith( + args[2].startswith( "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:" ) ) @@ -773,28 +730,15 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + "@TEST_USER_2:test", ANY, ANY, {}, ) - self.assertEqual(mxid, "@TEST_USER_2:test") def test_map_userinfo_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" - userinfo = { - "sub": "test2", - "username": "föö", - } - token = {} - - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, - ) - self.assertEqual(str(e.value), "localpart is invalid: föö") + self._make_callback_with_userinfo({"sub": "test2", "username": "föö"}) + self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( { @@ -807,6 +751,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_map_userinfo_to_user_retries(self): """The mapping provider can retry generating an MXID if the MXID is already in use.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + store = self.hs.get_datastore() self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -815,14 +762,13 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) - ) + self._make_callback_with_userinfo(userinfo) + # test_user is already taken, so test_user1 gets registered instead. - self.assertEqual(mxid, "@test_user1:test") + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", ANY, ANY, {}, + ) + auth_handler.complete_sso_login.reset_mock() # Register all of the potential mxids for a particular OIDC username. 
self.get_success( @@ -838,12 +784,70 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_not_called() + self.assertRenderedError( + "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) - self.assertEqual( - str(e.value), "Unable to generate a Matrix ID from the SSO response" + + def _make_callback_with_userinfo( + self, userinfo: dict, client_redirect_url: str = "http://client/redirect" + ) -> None: + self.handler._exchange_code = simple_async_mock(return_value={}) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + state = "state" + session = self.handler._generate_oidc_session_token( + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, + ) + request = self._build_callback_request("code", state, session) + + self.get_success(self.handler.handle_oidc_callback(request)) + + def _build_callback_request( + self, + code: str, + state: str, + session: str, + user_agent: str = "Browser", + ip_address: str = "10.0.0.1", + ): + """Builds a fake SynapseRequest to mock the browser callback + + Returns a Mock object which looks like the SynapseRequest we get from a browser + after SSO (before we return to the client) + + Args: + code: the authorization code which would have been returned by the OIDC + provider + state: the "state" param which would have been passed around in the + query param. Should be the same as was embedded in the session in + _build_oidc_session. + session: the "session" which would have been passed around in the cookie. + user_agent: the user-agent to present + ip_address: the IP address to pretend the request came from + """ + request = Mock( + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] ) + + request.getCookie.return_value = session + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + request.getClientIP.return_value = ip_address + request.get_user_agent.return_value = user_agent + return request -- cgit 1.5.1 From 1619802228033455ff6e5863c52556996b38e8c6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 14 Dec 2020 14:19:47 -0500 Subject: Various clean-ups to the logging context code (#8935) --- changelog.d/8916.misc | 2 +- changelog.d/8935.misc | 1 + synapse/config/logger.py | 2 +- synapse/http/site.py | 3 +-- synapse/logging/context.py | 24 +++++------------------- synapse/metrics/background_process_metrics.py | 7 +++---- synapse/replication/tcp/protocol.py | 3 +-- tests/handlers/test_federation.py | 6 +++--- tests/logging/test_terse_json.py | 7 ++----- tests/test_federation.py | 2 +- tests/test_utils/logging_setup.py | 2 +- 11 files changed, 20 insertions(+), 39 deletions(-) create mode 100644 changelog.d/8935.misc (limited to 'tests') diff --git a/changelog.d/8916.misc b/changelog.d/8916.misc index c71ef480e6..bf94135fd5 100644 --- a/changelog.d/8916.misc +++ b/changelog.d/8916.misc @@ -1 +1 @@ -Improve structured logging tests. +Various clean-ups to the structured logging and logging context code. 
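Most of the clean-up is mechanical: call sites stop assigning `request` to a
`LoggingContext` after construction and pass it to the constructor instead.
A minimal sketch of the call-site change (the `"GET-123"` request id is
purely illustrative; the real call sites are in the hunks below):

```python
from synapse.logging.context import LoggingContext

# Before this commit: the request id was attached after construction.
context = LoggingContext("request-name")
context.request = "GET-123"

# After this commit: the request id is passed straight to the constructor.
context = LoggingContext("request-name", request="GET-123")
```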
diff --git a/changelog.d/8935.misc b/changelog.d/8935.misc
new file mode 100644
index 0000000000..bf94135fd5
--- /dev/null
+++ b/changelog.d/8935.misc
@@ -0,0 +1 @@
+Various clean-ups to the structured logging and logging context code.
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index d4e887a3e0..4df3f93c1c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
     # filter options, but care must be taken when using e.g. MemoryHandler to buffer
     # writes.
 
-    log_context_filter = LoggingContextFilter(request="")
+    log_context_filter = LoggingContextFilter()
     log_metadata_filter = MetadataFilter({"server_name": config.server_name})
     old_factory = logging.getLogRecordFactory()
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5f0581dc3f..5a5790831b 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -128,8 +128,7 @@ class SynapseRequest(Request):
         # create a LogContext for this request
         request_id = self.get_request_id()
-        logcontext = self.logcontext = LoggingContext(request_id)
-        logcontext.request = request_id
+        self.logcontext = LoggingContext(request_id, request=request_id)
 
         # override the Server header which is set by twisted
         self.setHeader("Server", self.site.server_version_string)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774cc5..a507a83e93 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel:
     def copy_to(self, record):
         pass
 
-    def copy_to_twisted_log_entry(self, record):
-        record["request"] = None
-        record["scope"] = None
-
     def start(self, rusage: "Optional[resource._RUsage]"):
         pass
@@ -372,13 +368,6 @@ class LoggingContext:
         # we also track the current scope:
         record.scope = self.scope
 
-    def copy_to_twisted_log_entry(self, record) -> None:
-        """
-        Copy logging fields from this context to a Twisted log record.
-        """
-        record["request"] = self.request
-        record["scope"] = self.scope
-
     def start(self, rusage: "Optional[resource._RUsage]") -> None:
         """
         Record that this logcontext is currently running.
@@ -542,13 +531,10 @@ class LoggingContext:
 class LoggingContextFilter(logging.Filter):
     """Logging filter that adds values from the current logging context to each
     record.
 
-    Args:
-        **defaults: Default values to avoid formatters complaining about
-            missing fields
     """
 
-    def __init__(self, **defaults) -> None:
-        self.defaults = defaults
+    def __init__(self, request: str = ""):
+        self._default_request = request
 
     def filter(self, record) -> Literal[True]:
         """Add each field from the logging context to the record.
@@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter):
             True to include the record in the log output.
         """
         context = current_context()
-        for key, value in self.defaults.items():
-            setattr(record, key, value)
+        record.request = self._default_request
 
         # context should never be None, but if it somehow ends up being, then
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
-            context.copy_to(record)
+            # Logging is interested in the request.
+ record.request = context.request return True diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 76b7decf26..70e0fa45d9 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -199,8 +199,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar _background_process_start_count.labels(desc).inc() _background_process_in_flight_count.labels(desc).inc() - with BackgroundProcessLoggingContext(desc) as context: - context.request = "%s-%i" % (desc, count) + with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context: try: ctx = noop_context_manager() if bg_start_span: @@ -244,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext): __slots__ = ["_proc"] - def __init__(self, name: str): - super().__init__(name) + def __init__(self, name: str, request: Optional[str] = None): + super().__init__(name, request=request) self._proc = _BackgroundProcess(name, self) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index a509e599c2..804da994ea 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. ctx_name = "replication-conn-%s" % self.conn_id - self._logging_context = BackgroundProcessLoggingContext(ctx_name) - self._logging_context.request = ctx_name + self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name) def connectionMade(self): logger.info("[%s] Connection established", self.id()) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index d0452e1490..0b24b89a2e 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase): # the auth code requires that a signature exists, but doesn't check that # signature... go figure. 
join_event.signatures[other_server] = {"x": "y"} - with LoggingContext(request="send_join"): + with LoggingContext("send_join"): d = run_in_background( self.handler.on_send_join_request, other_server, join_event ) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index f6e7e5fdaa..48a74e2eee 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -117,11 +117,10 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): """ handler = logging.StreamHandler(self.output) handler.setFormatter(JsonFormatter()) - handler.addFilter(LoggingContextFilter(request="")) + handler.addFilter(LoggingContextFilter()) logger = self.get_logger(handler) - with LoggingContext() as context_one: - context_one.request = "test" + with LoggingContext(request="test"): logger.info("Hello there, %s!", "wally") log = self.get_log_line() @@ -132,9 +131,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): "level", "namespace", "request", - "scope", ] self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") self.assertEqual(log["request"], "test") - self.assertIsNone(log["scope"]) diff --git a/tests/test_federation.py b/tests/test_federation.py index fa45f8b3b7..fc9aab32d0 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - with LoggingContext(request="lying_event"): + with LoggingContext(): failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index fdfb840b62..52ae5c5713 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -48,7 +48,7 @@ def setup_logging(): handler = ToTwistedHandler() formatter = logging.Formatter(log_format) handler.setFormatter(formatter) - handler.addFilter(LoggingContextFilter(request="")) + handler.addFilter(LoggingContextFilter()) root_logger.addHandler(handler) log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR") -- cgit 1.5.1 From 6d02eb22dfde9551c515acaf73503e2500e00eaf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 14 Dec 2020 20:42:03 +0000 Subject: Fix startup failure with localdb_enabled: False (#8937) --- changelog.d/8937.bugfix | 1 + synapse/handlers/auth.py | 26 ++++++++++++-------------- tests/handlers/test_password_providers.py | 23 +++++++++++++++++++++++ 3 files changed, 36 insertions(+), 14 deletions(-) create mode 100644 changelog.d/8937.bugfix (limited to 'tests') diff --git a/changelog.d/8937.bugfix b/changelog.d/8937.bugfix new file mode 100644 index 0000000000..01e1848448 --- /dev/null +++ b/changelog.d/8937.bugfix @@ -0,0 +1 @@ +Fix bug introduced in Synapse v1.24.0 which would cause an exception on startup if both `enabled` and `localdb_enabled` were set to `False` in the `password_config` setting of the configuration file. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8deec4cd0c..21e568f226 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -198,27 +198,25 @@ class AuthHandler(BaseHandler): self._password_enabled = hs.config.password_enabled self._password_localdb_enabled = hs.config.password_localdb_enabled - # we keep this as a list despite the O(N^2) implication so that we can - # keep PASSWORD first and avoid confusing clients which pick the first - # type in the list. 
(NB that the spec doesn't require us to do so and - # clients which favour types that they don't understand over those that - # they do are technically broken) - # start out by assuming PASSWORD is enabled; we will remove it later if not. - login_types = [] + login_types = set() if self._password_localdb_enabled: - login_types.append(LoginType.PASSWORD) + login_types.add(LoginType.PASSWORD) for provider in self.password_providers: - if hasattr(provider, "get_supported_login_types"): - for t in provider.get_supported_login_types().keys(): - if t not in login_types: - login_types.append(t) + login_types.update(provider.get_supported_login_types().keys()) if not self._password_enabled: + login_types.discard(LoginType.PASSWORD) + + # Some clients just pick the first type in the list. In this case, we want + # them to use PASSWORD (rather than token or whatever), so we want to make sure + # that comes first, where it's present. + self._supported_login_types = [] + if LoginType.PASSWORD in login_types: + self._supported_login_types.append(LoginType.PASSWORD) login_types.remove(LoginType.PASSWORD) - - self._supported_login_types = login_types + self._supported_login_types.extend(login_types) # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index ceaf0902d2..8d50265145 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -430,6 +430,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **providers_config(CustomAuthProvider), + "password_config": {"enabled": False, "localdb_enabled": False}, + } + ) + def test_custom_auth_password_disabled_localdb_enabled(self): + """Check the localdb_enabled == enabled == False + + Regression test for https://github.com/matrix-org/synapse/issues/8914: check + that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't + cause an exception. + """ + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + @override_config( { **providers_config(PasswordCustomAuthProvider), -- cgit 1.5.1 From 01333681bc3db22541b49c194f5121a5415731c6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 15 Dec 2020 20:56:10 +0000 Subject: Preparatory refactoring of the SamlHandlerTestCase (#8938) * move simple_async_mock to test_utils ... so that it can be re-used * Remove references to `SamlHandler._map_saml_response_to_user` from tests This method is going away, so we can no longer use it as a test point. Instead, factor out a higher-level method which takes a SAML object, and verify correct behaviour by mocking out `AuthHandler.complete_sso_login`. 
* changelog
---
 changelog.d/8938.feature         |   1 +
 synapse/handlers/saml_handler.py |  23 +++++++
 tests/handlers/test_oidc.py      |  12 +---
 tests/handlers/test_saml.py      | 132 ++++++++++++++++++++++++++-------------
 tests/test_utils/__init__.py     |  12 ++++
 5 files changed, 126 insertions(+), 54 deletions(-)
 create mode 100644 changelog.d/8938.feature

(limited to 'tests')

diff --git a/changelog.d/8938.feature b/changelog.d/8938.feature
new file mode 100644
index 0000000000..d450ef4998
--- /dev/null
+++ b/changelog.d/8938.feature
@@ -0,0 +1 @@
+Add support for allowing users to pick their own user ID during a single-sign-on login.
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index f2ca1ddb53..6001fe3e27 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -163,6 +163,29 @@ class SamlHandler(BaseHandler):
             return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+        await self._handle_authn_response(request, saml2_auth, relay_state)
+
+    async def _handle_authn_response(
+        self,
+        request: SynapseRequest,
+        saml2_auth: saml2.response.AuthnResponse,
+        relay_state: str,
+    ) -> None:
+        """Handle an AuthnResponse, having parsed it from the request params
+
+        Assumes that the signature on the response object has been checked. Maps
+        the user onto an MXID, registering them if necessary, and returns a response
+        to the browser.
+
+        Args:
+            request: the incoming request from the browser. We'll respond to it with an
+                HTML page or a redirect
+            saml2_auth: the parsed AuthnResponse object
+            relay_state: the RelayState query param, which encodes the URI to redirect
+                back to
+        """
+
         for assertion in saml2_auth.assertions:
             # kibana limits the length of a log field, whereas this is all rather
             # useful, so split it up.
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 9878527bab..464e569ac8 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
 from synapse.handlers.sso import MappingException
 from synapse.types import UserID
 
-from tests.test_utils import FakeResponse
+from tests.test_utils import FakeResponse, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
@@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-def simple_async_mock(return_value=None, raises=None) -> Mock:
-    # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args, **kwargs):
-        if raises:
-            raise raises
-        return return_value
-
-    return Mock(side_effect=cb)
-
-
 async def get_json(url):
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index d21e5588ca..69927cf6be 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,11 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
+from mock import Mock
+
 import attr
 
 from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
+from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # Check if we have the dependencies to run the tests.
@@ -44,6 +48,8 @@ BASE_URL = "https://synapse/" @attr.s class FakeAuthnResponse: ava = attr.ib(type=dict) + assertions = attr.ib(type=list, factory=list) + in_response_to = attr.ib(type=Optional[str], default=None) class TestMappingProvider: @@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase): def test_map_saml_response_to_user(self): """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) - # The redirect_url doesn't matter with the default user mapping provider. - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri" ) - self.assertEqual(mxid, "@test_user:test") @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): @@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase): store.register_user(user_id="@test_user:test", password_hash=None) ) + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # Map a user via SSO. saml_response = FakeAuthnResponse( {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} ) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") # Subsequent calls should map to the same mxid. 
- mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") def test_map_saml_response_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # mock out the error renderer too + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) - redirect_url = "" - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), + ) + sso_handler.render_error.assert_called_once_with( + request, "mapping_error", "localpart is invalid: föö" ) - self.assertEqual(str(e.value), "localpart is invalid: föö") + auth_handler.complete_sso_login.assert_not_called() def test_map_saml_response_to_user_retries(self): """The mapping provider can retry generating an MXID if the MXID is already in use.""" + + # stub out the auth handler and error renderer + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + + # register a user to occupy the first-choice MXID store = self.hs.get_datastore() self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) + + # send the fake SAML response saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) + # test_user is already taken, so test_user1 gets registered instead. - self.assertEqual(mxid, "@test_user1:test") + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", request, "" + ) + auth_handler.complete_sso_login.reset_mock() # Register all of the potential mxids for a particular SAML username. self.get_success( @@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase): # Now attempt to map to a username, this will fail since all potential usernames are taken. 
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) - self.assertEqual( - str(e.value), "Unable to generate a Matrix ID from the SSO response" + sso_handler.render_error.assert_called_once_with( + request, + "mapping_error", + "Unable to generate a Matrix ID from the SSO response", ) + auth_handler.complete_sso_login.assert_not_called() @override_config( { @@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase): } ) def test_map_saml_response_redirect(self): + """Test a mapping provider that raises a RedirectException""" + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" + request = _mock_request() e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), + self.handler._handle_authn_response(request, saml_response, ""), RedirectException, ) self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 6873d45eb6..43898d8142 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -22,6 +22,8 @@ import warnings from asyncio import Future from typing import Any, Awaitable, Callable, TypeVar +from mock import Mock + import attr from twisted.python.failure import Failure @@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]: return cleanup +def simple_async_mock(return_value=None, raises=None) -> Mock: + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + @attr.s class FakeResponse: """A fake twisted.web.IResponse object -- cgit 1.5.1 From 7eebe4b3fc3129e4571d58c3cea5eeccc584e072 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 15 Dec 2020 14:51:25 +0000 Subject: Replace `request.code` with `channel.code` The two are equivalent, but really we want to check the HTTP result that got returned to the channel, not the code that the Request object *intended* to return to the channel. 
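For illustration, a minimal sketch of the assertion pattern this commit moves towards, written against Synapse's `HomeserverTestCase` helpers (the `/health` endpoint is just a convenient example from the test suite):

    # at this point in history, make_request returns a
    # (SynapseRequest, FakeChannel) pair
    request, channel = self.make_request("GET", "/health", shorthand=False)

    # assert on the code that was actually written to the channel, rather
    # than the code the Request object intended to return
    self.assertEqual(channel.code, 200)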
--- tests/http/test_additional_resource.py | 4 ++-- tests/replication/test_auth.py | 10 +++++----- tests/replication/test_client_reader_shard.py | 8 ++++---- tests/rest/client/v2_alpha/test_account.py | 4 ++-- tests/rest/client/v2_alpha/test_auth.py | 10 +++++----- tests/rest/client/v2_alpha/test_register.py | 2 +- tests/rest/test_health.py | 2 +- tests/rest/test_well_known.py | 4 ++-- 8 files changed, 22 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index 05e9c449be..adc318caad 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -48,7 +48,7 @@ class AdditionalResourceTests(HomeserverTestCase): request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) def test_sync(self): @@ -57,5 +57,5 @@ class AdditionalResourceTests(HomeserverTestCase): request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index fe9e4d5f9a..aee839dc69 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -66,7 +66,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_1.code, 401) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] @@ -84,7 +84,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): """With no authentication the request should finish. """ request, channel = self._test_register() - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") @@ -94,7 +94,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): """If the main process expects a secret that is not provided, an error results. """ request, channel = self._test_register() - self.assertEqual(request.code, 500) + self.assertEqual(channel.code, 500) @override_config( { @@ -106,14 +106,14 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): """If the main process receives the wrong secret, an error results. """ request, channel = self._test_register() - self.assertEqual(request.code, 500) + self.assertEqual(channel.code, 500) @override_config({"worker_replication_secret": "my-secret"}) def test_authorized(self): """The request should finish when the worker provides the authentication header. """ request, channel = self._test_register() - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) # We're given a registered user. 
self.assertEqual(channel.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index fdaad3d8ad..6cdf6a099b 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -48,7 +48,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_1.code, 401) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] @@ -61,7 +61,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"auth": {"session": session, "type": "m.login.dummy"}}, ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_2.code, 200) + self.assertEqual(channel_2.code, 200) # We're given a registered user. self.assertEqual(channel_2.json_body["user_id"], "@user:test") @@ -80,7 +80,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_1.code, 401) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] @@ -94,7 +94,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): "register", {"auth": {"session": session, "type": "m.login.dummy"}}, ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_2.code, 200) + self.assertEqual(channel_2.code, 200) # We're given a registered user. self.assertEqual(channel_2.json_body["user_id"], "@user:test") diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 2ac1ecb7d3..11a042f9e9 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -353,7 +353,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): # Check that this access token has been invalidated. 
request, channel = self.make_request("GET", "account/whoami") - self.assertEqual(request.code, 401) + self.assertEqual(channel.code, 401) def test_pending_invites(self): """Tests that deactivating a user rejects every pending invite for them.""" @@ -410,7 +410,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index ac67a9de29..a35c2646fb 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -71,7 +71,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): "POST", "register", body ) # type: SynapseRequest, FakeChannel - self.assertEqual(request.code, expected_response) + self.assertEqual(channel.code, expected_response) return channel def recaptcha( @@ -84,7 +84,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session ) # type: SynapseRequest, FakeChannel - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) request, channel = self.make_request( "POST", @@ -92,7 +92,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): + post_session + "&g-recaptcha-response=a", ) - self.assertEqual(request.code, expected_post_response) + self.assertEqual(channel.code, expected_post_response) # The recaptcha handler is called with the response given attempts = self.recaptcha_checker.recaptcha_attempts @@ -201,7 +201,7 @@ class UIAuthTests(unittest.HomeserverTestCase): ) # type: SynapseRequest, FakeChannel # Ensure the response is sane. - self.assertEqual(request.code, expected_response) + self.assertEqual(channel.code, expected_response) return channel @@ -214,7 +214,7 @@ class UIAuthTests(unittest.HomeserverTestCase): ) # type: SynapseRequest, FakeChannel # Ensure the response is sane. 
- self.assertEqual(request.code, expected_response) + self.assertEqual(channel.code, expected_response) return channel diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index bcb21d0ced..4bf3e0d63b 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -559,7 +559,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.reactor.advance(datetime.timedelta(days=8).total_seconds()) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py index 02a46e5fda..aaf2fb821b 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py @@ -27,5 +27,5 @@ class HealthCheckTests(unittest.HomeserverTestCase): def test_health(self): request, channel = self.make_request("GET", "/health", shorthand=False) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 6a930f4148..17ded96b9c 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -32,7 +32,7 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, { @@ -48,4 +48,4 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(request.code, 404) + self.assertEqual(channel.code, 404) -- cgit 1.5.1 From 0378581c1335e7e7e9dc0b482b5bf6efb63859be Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 15 Dec 2020 22:34:20 +0000 Subject: remove 'request' result from `_get_shared_rooms` --- tests/rest/client/v2_alpha/test_shared_rooms.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py index 562a9c1ba4..05c5ee5a78 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -17,6 +17,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import shared_rooms from tests import unittest +from tests.server import FakeChannel class UserSharedRoomsTest(unittest.HomeserverTestCase): @@ -40,14 +41,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.store = hs.get_datastore() self.handler = hs.get_user_directory_handler() - def _get_shared_rooms(self, token, other_user): - request, channel = self.make_request( + def _get_shared_rooms(self, token, other_user) -> FakeChannel: + _, channel = self.make_request( "GET", "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" % other_user, access_token=token, ) - return request, channel + return channel def test_shared_room_list_public(self): """ @@ -63,7 +64,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.invite(room, src=u1, targ=u2, tok=u1_token) self.helper.join(room, user=u2, tok=u2_token) - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 1) 
self.assertEquals(channel.json_body["joined"][0], room) @@ -82,7 +83,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.invite(room, src=u1, targ=u2, tok=u1_token) self.helper.join(room, user=u2, tok=u2_token) - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 1) self.assertEquals(channel.json_body["joined"][0], room) @@ -104,7 +105,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.join(room_public, user=u2, tok=u2_token) self.helper.join(room_private, user=u1, tok=u1_token) - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 2) self.assertTrue(room_public in channel.json_body["joined"]) @@ -125,13 +126,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Assert user directory is not empty - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 1) self.assertEquals(channel.json_body["joined"][0], room) self.helper.leave(room, user=u1, tok=u1_token) - request, channel = self._get_shared_rooms(u2_token, u1) + channel = self._get_shared_rooms(u2_token, u1) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 0) -- cgit 1.5.1 From 5bcf6e8289b075ef298510faae12041811bbd344 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 15 Dec 2020 14:54:41 +0000 Subject: Skip redundant check on `request.args` --- tests/test_server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/test_server.py b/tests/test_server.py index 6b2d2f0401..0be8c7e2fc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -64,11 +64,10 @@ class JsonResourceTests(unittest.TestCase): "test_servlet", ) - request, channel = make_request( + make_request( self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" ) - self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]}) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) def test_callback_direct_exception(self): -- cgit 1.5.1 From ac2acf1524ea33a37b5cdbd4cebcb149c80a48c7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 15 Dec 2020 15:18:03 +0000 Subject: Remove redundant reading of SynapseRequest.args this didn't seem to be doing a lot, so remove it. 
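For context, a minimal sketch of the round-trip the removed code performed (the `args` dict below is a hypothetical stand-in for Twisted's parsed `request.args`, which maps bytes keys to lists of bytes values):

    from urllib.parse import urlencode

    # stand-in for request.args on a parsed form submission
    args = {b"token": [b"sometoken"]}

    # the removed code re-encoded the parsed arguments back into a form body...
    form_args = [
        (key, value) for key, value_list in args.items() for value in value_list
    ]
    body = urlencode(form_args).encode("utf8")

    # ...which just reproduces the urlencoded bytes that were parsed in the
    # first place
    assert body == b"token=sometoken"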
--- tests/rest/client/v2_alpha/test_account.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 11a042f9e9..8a898be24c 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -19,7 +19,6 @@ import os import re from email.parser import Parser from typing import Optional -from urllib.parse import urlencode import pkg_resources @@ -268,20 +267,13 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button - # Send arguments as url-encoded form data, matching the template's behaviour - form_args = [] - for key, value_list in request.args.items(): - for value in value_list: - arg = (key, value) - form_args.append(arg) - # Confirm the password reset request, channel = make_request( self.reactor, FakeSite(self.submit_token_resource), "POST", path, - content=urlencode(form_args).encode("utf8"), + content=b"", shorthand=False, content_is_form=True, ) -- cgit 1.5.1 From 394516ad1bb6127ab5b32a12d81ef307deb39570 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 15 Dec 2020 14:44:04 +0000 Subject: Remove spurious "SynapseRequest" result from `make_request` This was never used, so let's get rid of it. --- tests/app/test_frontend_proxy.py | 4 +- tests/app/test_openid_listener.py | 4 +- tests/federation/test_complexity.py | 4 +- tests/federation/test_federation_server.py | 6 +- tests/federation/transport/test_server.py | 8 +- tests/handlers/test_directory.py | 12 +- tests/handlers/test_message.py | 2 +- tests/handlers/test_password_providers.py | 6 +- tests/handlers/test_typing.py | 2 +- tests/handlers/test_user_directory.py | 4 +- tests/http/test_additional_resource.py | 4 +- tests/push/test_http.py | 2 +- tests/replication/test_auth.py | 14 +- tests/replication/test_client_reader_shard.py | 16 +- tests/replication/test_multi_media_repo.py | 2 +- tests/replication/test_sharded_event_persister.py | 16 +- tests/rest/admin/test_admin.py | 32 +-- tests/rest/admin/test_device.py | 114 +++----- tests/rest/admin/test_event_reports.py | 62 ++-- tests/rest/admin/test_media.py | 62 ++-- tests/rest/admin/test_room.py | 98 +++---- tests/rest/admin/test_statistics.py | 60 ++-- tests/rest/admin/test_user.py | 314 ++++++++------------- tests/rest/client/test_consent.py | 8 +- tests/rest/client/test_ephemeral_message.py | 2 +- tests/rest/client/test_identity.py | 6 +- tests/rest/client/test_redactions.py | 8 +- tests/rest/client/test_retention.py | 2 +- tests/rest/client/test_shadow_banned.py | 16 +- tests/rest/client/test_third_party_rules.py | 10 +- tests/rest/client/v1/test_directory.py | 8 +- tests/rest/client/v1/test_events.py | 8 +- tests/rest/client/v1/test_login.py | 84 +++--- tests/rest/client/v1/test_presence.py | 4 +- tests/rest/client/v1/test_profile.py | 16 +- tests/rest/client/v1/test_push_rule_attrs.py | 66 ++--- tests/rest/client/v1/test_rooms.py | 224 +++++++-------- tests/rest/client/v1/test_typing.py | 8 +- tests/rest/client/v1/utils.py | 18 +- tests/rest/client/v2_alpha/test_account.py | 44 ++- tests/rest/client/v2_alpha/test_auth.py | 20 +- tests/rest/client/v2_alpha/test_capabilities.py | 8 +- tests/rest/client/v2_alpha/test_filter.py | 14 +- tests/rest/client/v2_alpha/test_password_policy.py | 18 +- tests/rest/client/v2_alpha/test_register.py | 68 ++---
tests/rest/client/v2_alpha/test_relations.py | 36 +-- tests/rest/client/v2_alpha/test_shared_rooms.py | 3 +- tests/rest/client/v2_alpha/test_sync.py | 32 +-- tests/rest/media/v1/test_media_storage.py | 4 +- tests/rest/media/v1/test_url_preview.py | 38 +-- tests/rest/test_health.py | 2 +- tests/rest/test_well_known.py | 4 +- tests/server.py | 8 +- tests/server_notices/test_consent.py | 6 +- .../test_resource_limits_server_notices.py | 8 +- tests/test_mau.py | 4 +- tests/test_server.py | 22 +- tests/test_terms_auth.py | 6 +- tests/unittest.py | 44 +-- 59 files changed, 742 insertions(+), 983 deletions(-) (limited to 'tests') diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 43fef5d64a..e0ca288829 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -57,7 +57,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - _, channel = make_request(self.reactor, site, "PUT", "presence/a/status") + channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 400 + unrecognised, because nothing is registered self.assertEqual(channel.code, 400) @@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - _, channel = make_request(self.reactor, site, "PUT", "presence/a/status") + channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 401, because the stub servlet still checks authentication self.assertEqual(channel.code, 401) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index b260ab734d..467033e201 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -73,7 +73,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): return raise - _, channel = make_request( + channel = make_request( self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) @@ -121,7 +121,7 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): return raise - _, channel = make_request( + channel = make_request( self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 0187f56e21..9ccd2d76b8 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -48,7 +48,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): ) # Get the room complexity - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) self.assertEquals(200, channel.code) @@ -60,7 +60,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) self.assertEquals(200, channel.code) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 3009fbb6c4..cfeccc0577 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -46,7 +46,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/get_missing_events/(?P[^/]*)/?" 
- request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, @@ -95,7 +95,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): room_1 = self.helper.create_room_as(u1, tok=u1_token) self.inject_room_member(room_1, "@user:other.example.com", "join") - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) self.assertEquals(200, channel.code, channel.result) @@ -127,7 +127,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): room_1 = self.helper.create_room_as(u1, tok=u1_token) - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index f9e3c7a51f..212fb79a00 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -37,14 +37,10 @@ class RoomDirectoryFederationTests(unittest.HomeserverTestCase): @override_config({"allow_public_rooms_over_federation": False}) def test_blocked_public_room_list_over_federation(self): - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/publicRooms" - ) + channel = self.make_request("GET", "/_matrix/federation/v1/publicRooms") self.assertEquals(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) def test_open_public_room_list_over_federation(self): - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/publicRooms" - ) + channel = self.make_request("GET", "/_matrix/federation/v1/publicRooms") self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 770d225ed5..a39f898608 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -405,7 +405,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): def test_denied(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/room/%23test%3Atest", ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), @@ -415,7 +415,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): def test_allowed(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/room/%23unofficial_test%3Atest", ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), @@ -431,7 +431,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.assertEquals(200, channel.code, channel.result) @@ -446,7 +446,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = True # Room list is enabled so we should get some results - request, channel = self.make_request("GET", b"publicRooms") + channel = self.make_request("GET", b"publicRooms") self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) @@ -454,13 +454,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = 
False # Room list disabled so we should get no results - request, channel = self.make_request("GET", b"publicRooms") + channel = self.make_request("GET", b"publicRooms") self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) # Room list disabled so we shouldn't be allowed to publish rooms room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index af42775815..f955dfa490 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -206,7 +206,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): # Redaction of event should fail. path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id) - request, channel = self.make_request( + channel = self.make_request( "POST", path, content={}, access_token=self.access_token ) self.assertEqual(int(channel.result["code"]), 403) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 8d50265145..f816594ee4 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -551,7 +551,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) def _get_login_flows(self) -> JsonDict: - _, channel = self.make_request("GET", "/_matrix/client/r0/login") + channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) return channel.json_body["flows"] @@ -560,7 +560,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): def _send_login(self, type, user, **params) -> FakeChannel: params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) - _, channel = self.make_request("POST", "/_matrix/client/r0/login", params) + channel = self.make_request("POST", "/_matrix/client/r0/login", params) return channel def _start_delete_device_session(self, access_token, device_id) -> str: @@ -597,7 +597,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"", ) -> FakeChannel: """Delete an individual device.""" - _, channel = self.make_request( + channel = self.make_request( "DELETE", "devices/" + device, body, access_token=access_token ) return channel diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f21de958f1..96e5bdac4a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -220,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - (request, channel) = self.make_request( + channel = self.make_request( "PUT", "/_matrix/federation/v1/send/1000000", _make_edu_transaction_json( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 647a17cb90..1260721dbf 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -534,7 +534,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): self.helper.join(room, user=u2) # Assert user directory is not empty - request, channel = self.make_request( + channel = self.make_request( "POST", b"user_directory/search", 
b'{"search_term":"user2"}' ) self.assertEquals(200, channel.code, channel.result) @@ -542,7 +542,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Disable user directory and check search returns nothing self.config.user_directory_search_enabled = False - request, channel = self.make_request( + channel = self.make_request( "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.assertEquals(200, channel.code, channel.result) diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index adc318caad..453391a5a5 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -46,7 +46,7 @@ class AdditionalResourceTests(HomeserverTestCase): handler = _AsyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) - request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") + channel = make_request(self.reactor, FakeSite(resource), "GET", "/") self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) @@ -55,7 +55,7 @@ class AdditionalResourceTests(HomeserverTestCase): handler = _SyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) - request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") + channel = make_request(self.reactor, FakeSite(resource), "GET", "/") self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 8b4af74c51..cb3245d8cf 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -667,7 +667,7 @@ class HTTPPusherTests(HomeserverTestCase): # This will actually trigger a new notification to be sent out so that # even if the user does not receive another message, their unread # count goes down - request, channel = self.make_request( + channel = self.make_request( "POST", "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), {}, diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index aee839dc69..c5ab3032a5 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -47,7 +47,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): return config - def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]: + def _test_register(self) -> FakeChannel: """Run the actual test: 1. Create a worker homeserver. @@ -59,13 +59,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): worker_hs = self.make_worker_hs("synapse.app.client_reader") site = self._hs_to_site[worker_hs] - request_1, channel_1 = make_request( + channel_1 = make_request( self.reactor, site, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel + ) self.assertEqual(channel_1.code, 401) # Grab the session @@ -83,7 +83,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): def test_no_auth(self): """With no authentication the request should finish. """ - request, channel = self._test_register() + channel = self._test_register() self.assertEqual(channel.code, 200) # We're given a registered user. @@ -93,7 +93,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): def test_missing_auth(self): """If the main process expects a secret that is not provided, an error results. 
""" - request, channel = self._test_register() + channel = self._test_register() self.assertEqual(channel.code, 500) @override_config( @@ -105,14 +105,14 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): def test_unauthorized(self): """If the main process receives the wrong secret, an error results. """ - request, channel = self._test_register() + channel = self._test_register() self.assertEqual(channel.code, 500) @override_config({"worker_replication_secret": "my-secret"}) def test_authorized(self): """The request should finish when the worker provides the authentication header. """ - request, channel = self._test_register() + channel = self._test_register() self.assertEqual(channel.code, 200) # We're given a registered user. diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index 6cdf6a099b..abcc74f932 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -41,26 +41,26 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): worker_hs = self.make_worker_hs("synapse.app.client_reader") site = self._hs_to_site[worker_hs] - request_1, channel_1 = make_request( + channel_1 = make_request( self.reactor, site, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel + ) self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] # also complete the dummy auth - request_2, channel_2 = make_request( + channel_2 = make_request( self.reactor, site, "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}, - ) # type: SynapseRequest, FakeChannel + ) self.assertEqual(channel_2.code, 200) # We're given a registered user. @@ -73,13 +73,13 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") site_1 = self._hs_to_site[worker_hs_1] - request_1, channel_1 = make_request( + channel_1 = make_request( self.reactor, site_1, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel + ) self.assertEqual(channel_1.code, 401) # Grab the session @@ -87,13 +87,13 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): # also complete the dummy auth site_2 = self._hs_to_site[worker_hs_2] - request_2, channel_2 = make_request( + channel_2 = make_request( self.reactor, site_2, "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}, - ) # type: SynapseRequest, FakeChannel + ) self.assertEqual(channel_2.code, 200) # We're given a registered user. diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 83afd9fd2f..d1feca961f 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): the media which the caller should respond to. 
""" resource = hs.get_media_repository_resource().children[b"download"] - _, channel = make_request( + channel = make_request( self.reactor, FakeSite(resource), "GET", diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 77fc3856d5..8d494ebc03 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -180,7 +180,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) # Do an initial sync so that we're up to date. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token ) next_batch = channel.json_body["next_batch"] @@ -206,7 +206,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Check that syncing still gets the new event, despite the gap in the # stream IDs. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -236,7 +236,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) first_event_in_room2 = response["event_id"] - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -261,7 +261,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -279,7 +279,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Paginating back in the first room should not produce any results, as # no events have happened in it. This tests that we are correctly # filtering results based on the vector clock portion. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -292,7 +292,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Paginating back on the second room should produce the first event # again. This tests that pagination isn't completely broken. 
- request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -307,7 +307,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) # Paginating forwards should give the same results - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -318,7 +318,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) self.assertListEqual([], channel.json_body["chunk"]) - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 67d8878395..0504cd187e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -42,7 +42,7 @@ class VersionTestCase(unittest.HomeserverTestCase): return resource def test_version_string(self): - request, channel = self.make_request("GET", self.url, shorthand=False) + channel = self.make_request("GET", self.url, shorthand=False) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -68,7 +68,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def test_delete_group(self): # Create a new group - request, channel = self.make_request( + channel = self.make_request( "POST", "/create_group".encode("ascii"), access_token=self.admin_user_tok, @@ -84,13 +84,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Invite/join another user url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -101,7 +101,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Now delete the group url = "/_synapse/admin/v1/delete_group/" + group_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=self.admin_user_tok, @@ -123,7 +123,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): """ url = "/groups/%s/profile" % (group_id,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) @@ -134,7 +134,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token) """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/joined_groups".encode("ascii"), access_token=access_token ) @@ -216,7 +216,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): def _ensure_quarantined(self, admin_user_tok, server_and_media_id): """Ensure a piece of media is quarantined when trying to access it.""" - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.download_resource), "GET", @@ -241,7 +241,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Attempt quarantine media APIs as non-admin url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345" - request, channel = self.make_request( + channel = self.make_request( "POST", 
url.encode("ascii"), access_token=non_admin_user_tok, ) @@ -254,7 +254,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # And the roomID/userID endpoint url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine" - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=non_admin_user_tok, ) @@ -282,7 +282,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): server_name, media_id = server_name_and_media_id.split("/") # Attempt to access the media - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.download_resource), "GET", @@ -299,7 +299,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): urllib.parse.quote(server_name), urllib.parse.quote(media_id), ) - request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + channel = self.make_request("POST", url, access_token=admin_user_tok,) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) @@ -351,7 +351,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( room_id ) - request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + channel = self.make_request("POST", url, access_token=admin_user_tok,) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) self.assertEqual( @@ -395,7 +395,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( non_admin_user ) - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=admin_user_tok, ) self.pump(1.0) @@ -437,7 +437,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( non_admin_user ) - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=admin_user_tok, ) self.pump(1.0) @@ -453,7 +453,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self._ensure_quarantined(admin_user_tok, server_and_media_id_1) # Attempt to access each piece of media - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.download_resource), "GET", diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index cf3a007598..248c4442c3 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -50,17 +50,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ Try to get a device of an user without authentication. 
""" - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - request, channel = self.make_request("PUT", self.url, b"{}") + channel = self.make_request("PUT", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - request, channel = self.make_request("DELETE", self.url, b"{}") + channel = self.make_request("DELETE", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -69,21 +69,21 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url, access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url, access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - request, channel = self.make_request( + channel = self.make_request( "DELETE", self.url, access_token=self.other_user_token, ) @@ -99,23 +99,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): % self.other_user_device_id ) - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - request, channel = self.make_request( - "PUT", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -129,23 +123,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): % self.other_user_device_id ) - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - request, channel = self.make_request( - "PUT", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) 
self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -158,22 +146,16 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.other_user ) - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - request, channel = self.make_request( - "PUT", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) # Delete unknown device returns status 200 self.assertEqual(200, channel.code, msg=channel.json_body) @@ -197,7 +179,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): } body = json.dumps(update) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, @@ -208,9 +190,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -227,16 +207,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( - "PUT", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure the display name was not updated. 
- request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -247,7 +223,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ # Set new display_name body = json.dumps({"display_name": "new displayname"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, @@ -257,9 +233,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) # Check new display_name - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) @@ -268,9 +242,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ Tests that a normal lookup for a device is successfully """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) @@ -291,7 +263,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(1, number_devices) # Delete device - request, channel = self.make_request( + channel = self.make_request( "DELETE", self.url, access_token=self.admin_user_tok, ) @@ -323,7 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ Try to list devices of an user without authentication. 
""" - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -334,9 +306,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -346,9 +316,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -359,9 +327,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -373,9 +339,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ # Get devices - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -391,9 +355,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.login("user", "pass") # Get devices - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) @@ -431,7 +393,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ Try to delete devices of an user without authentication. 
""" - request, channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request("POST", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -442,9 +404,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "POST", self.url, access_token=other_user_token, - ) + channel = self.make_request("POST", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -454,9 +414,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" - request, channel = self.make_request( - "POST", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("POST", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -467,9 +425,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" - request, channel = self.make_request( - "POST", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("POST", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -479,7 +435,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): Tests that a remove of a device that does not exist returns 200. """ body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, @@ -510,7 +466,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): # Delete devices body = json.dumps({"devices": device_ids}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 11b72c10f7..aa389df12f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -74,7 +74,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ Try to get an event report without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -84,9 +84,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. 
""" - request, channel = self.make_request( - "GET", self.url, access_token=self.other_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -96,9 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -111,7 +107,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with limit """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) @@ -126,7 +122,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a defined starting point (from) """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5", access_token=self.admin_user_tok, ) @@ -141,7 +137,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a defined starting point and limit """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) @@ -156,7 +152,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a filter of room """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?room_id=%s" % self.room_id1, access_token=self.admin_user_tok, @@ -176,7 +172,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a filter of user """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?user_id=%s" % self.other_user, access_token=self.admin_user_tok, @@ -196,7 +192,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a filter of user and room """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1), access_token=self.admin_user_tok, @@ -218,7 +214,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ # fetch the most recent first, largest timestamp - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=b", access_token=self.admin_user_tok, ) @@ -234,7 +230,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): report += 1 # fetch the oldest first, smallest timestamp - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=f", access_token=self.admin_user_tok, ) @@ -254,7 +250,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing that a invalid search order returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, ) @@ -267,7 +263,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing that a negative limit parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=-5", 
access_token=self.admin_user_tok, ) @@ -279,7 +275,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing that a negative from parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) @@ -293,7 +289,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) @@ -304,7 +300,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of max results is larger than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=21", access_token=self.admin_user_tok, ) @@ -315,7 +311,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does appear # Number of max results is smaller than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=19", access_token=self.admin_user_tok, ) @@ -327,7 +323,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # Check # Set `from` to value of `next_token` for request remaining entries # `next_token` does not appear - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=19", access_token=self.admin_user_tok, ) @@ -342,7 +338,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] - request, channel = self.make_request( + channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), json.dumps({"score": -100, "reason": "this makes me sad"}), @@ -399,7 +395,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ Try to get event report without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -409,9 +405,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. 
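The `next_token` assertions above follow a simple rule: a token is returned only when entries remain beyond the current page, and its value is the offset at which to resume. A sketch consistent with those assertions (hypothetical helper, not the server's code):

    def paginate(entries: list, start: int = 0, limit: int = 100) -> dict:
        page = entries[start : start + limit]
        result = {"event_reports": page, "total": len(entries)}
        if start + limit < len(entries):
            # More entries remain, so tell the client where to resume.
            result["next_token"] = start + limit
        return result

    reports = list(range(20))
    assert "next_token" not in paginate(reports, limit=20)   # exact fit
    assert "next_token" not in paginate(reports, limit=21)   # limit > total
    assert paginate(reports, limit=19)["next_token"] == 19   # one entry left
    assert "next_token" not in paginate(reports, start=19)   # final page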
""" - request, channel = self.make_request( - "GET", self.url, access_token=self.other_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -421,9 +415,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): Testing get a reported event """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self._check_fields(channel.json_body) @@ -434,7 +426,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ # `report_id` is negative - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v1/event_reports/-123", access_token=self.admin_user_tok, @@ -448,7 +440,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) # `report_id` is a non-numerical string - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v1/event_reports/abcdef", access_token=self.admin_user_tok, @@ -462,7 +454,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) # `report_id` is undefined - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v1/event_reports/", access_token=self.admin_user_tok, @@ -480,7 +472,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): Testing that a not existing `report_id` returns a 404. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v1/event_reports/123", access_token=self.admin_user_tok, @@ -496,7 +488,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] - request, channel = self.make_request( + channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), json.dumps({"score": -100, "reason": "this makes me sad"}), diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index dadf9db660..c2b998cdae 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -50,7 +50,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - request, channel = self.make_request("DELETE", url, b"{}") + channel = self.make_request("DELETE", url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -64,9 +64,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - request, channel = self.make_request( - "DELETE", url, access_token=self.other_user_token, - ) + channel = self.make_request("DELETE", url, access_token=self.other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -77,9 +75,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, 
access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -90,9 +86,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) @@ -121,7 +115,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(server_name, self.server_name) # Attempt to access media - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(download_resource), "GET", @@ -146,9 +140,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id) # Delete media - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -157,7 +149,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) # Attempt to access media - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(download_resource), "GET", @@ -204,7 +196,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): Try to delete media without authentication. """ - request, channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request("POST", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -216,7 +208,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.other_user = self.register_user("user", "pass") self.other_user_token = self.login("user", "pass") - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, access_token=self.other_user_token, ) @@ -229,7 +221,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" - request, channel = self.make_request( + channel = self.make_request( "POST", url + "?before_ts=1234", access_token=self.admin_user_tok, ) @@ -240,9 +232,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ If the parameter `before_ts` is missing, an error is returned. """ - request, channel = self.make_request( - "POST", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) @@ -254,7 +244,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ If parameters are invalid, an error is returned. 
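The invalid-parameter cases that follow reduce to three checks on the query string. A condensed sketch of that validation (illustrative names and return shape; the real handler raises an error rather than returning tuples, and the exact error strings are not reproduced here):

    def validate_delete_params(args: dict):
        # before_ts is required and must be a non-negative integer.
        if "before_ts" not in args:
            return 400, "M_MISSING_PARAM"
        if int(args["before_ts"]) < 0:
            return 400, "M_INVALID_PARAM"
        # size_gt, if present, must also be non-negative.
        if int(args.get("size_gt", "0")) < 0:
            return 400, "M_INVALID_PARAM"
        # keep_profiles, if present, must be a boolean literal.
        if args.get("keep_profiles", "false") not in ("true", "false"):
            return 400, "M_INVALID_PARAM"
        return 200, None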
""" - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok, ) @@ -265,7 +255,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=1234&size_gt=-1234", access_token=self.admin_user_tok, @@ -278,7 +268,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=1234&keep_profiles=not_bool", access_token=self.admin_user_tok, @@ -308,7 +298,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # timestamp after upload/create now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, @@ -332,7 +322,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, @@ -344,7 +334,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # timestamp after upload now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, @@ -367,7 +357,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", access_token=self.admin_user_tok, @@ -378,7 +368,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", access_token=self.admin_user_tok, @@ -401,7 +391,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) # set media as avatar - request, channel = self.make_request( + channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.admin_user,), content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), @@ -410,7 +400,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, @@ -421,7 +411,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, @@ -445,7 +435,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # set media as room avatar room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - request, channel = self.make_request( + channel = self.make_request( "PUT", 
"/rooms/%s/state/m.room.avatar" % (room_id,), content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), @@ -454,7 +444,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, @@ -465,7 +455,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, @@ -512,7 +502,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): media_id = server_and_media_id.split("/")[1] local_path = self.filepaths.local_media_filepath(media_id) - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(download_resource), "GET", diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 9c100050d2..ca20bcad08 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -79,7 +79,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/shutdown_room/" + room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), @@ -103,7 +103,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Enable world readable url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), json.dumps({"history_visibility": "world_readable"}), @@ -113,7 +113,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/shutdown_room/" + room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), @@ -130,7 +130,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): """ url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.assertEqual( @@ -138,7 +138,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): ) url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.assertEqual( @@ -184,7 +184,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. 
""" - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, json.dumps({}), access_token=self.other_user_tok, ) @@ -197,7 +197,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/rooms/!unknown:test/delete" - request, channel = self.make_request( + channel = self.make_request( "POST", url, json.dumps({}), access_token=self.admin_user_tok, ) @@ -210,7 +210,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/rooms/invalidroom/delete" - request, channel = self.make_request( + channel = self.make_request( "POST", url, json.dumps({}), access_token=self.admin_user_tok, ) @@ -225,7 +225,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"new_room_user_id": "@unknown:test"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -244,7 +244,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"new_room_user_id": "@not:exist.bla"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -262,7 +262,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"block": "NotBool"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -278,7 +278,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"purge": "NotBool"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -304,7 +304,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": True, "purge": True}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url.encode("ascii"), content=body.encode(encoding="utf_8"), @@ -337,7 +337,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": False, "purge": True}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url.encode("ascii"), content=body.encode(encoding="utf_8"), @@ -371,7 +371,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": False, "purge": False}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url.encode("ascii"), content=body.encode(encoding="utf_8"), @@ -418,7 +418,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), @@ -448,7 +448,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Enable world readable url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), json.dumps({"history_visibility": "world_readable"}), @@ -465,7 +465,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), @@ -530,7 +530,7 @@ class 
DeleteRoomTestCase(unittest.HomeserverTestCase): """ url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.assertEqual( @@ -538,7 +538,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.assertEqual( @@ -569,7 +569,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) url = "/_synapse/admin/v1/purge_room" - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), {"room_id": room_id}, @@ -623,7 +623,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Request the list of rooms url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) @@ -704,7 +704,7 @@ class RoomTestCase(unittest.HomeserverTestCase): limit, "name", ) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual( @@ -744,7 +744,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_ids, returned_room_ids) url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -764,7 +764,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Create a new alias to this room url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), {"room_id": room_id}, @@ -794,7 +794,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Request the list of rooms url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -835,7 +835,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_matrix/client/r0/directory/room/%s" % ( urllib.parse.quote(test_alias), ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), {"room_id": room_id}, @@ -875,7 +875,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) if reverse: url += "&dir=b" - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1011,7 +1011,7 @@ class RoomTestCase(unittest.HomeserverTestCase): expected_http_code: The expected http code for the request """ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) @@ -1072,7 +1072,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = 
self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1102,7 +1102,7 @@ class RoomTestCase(unittest.HomeserverTestCase): room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1114,7 +1114,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.helper.join(room_id_1, user_1, tok=user_tok_1) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1124,7 +1124,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok) self.helper.leave(room_id_1, user_1, tok=user_tok_1) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1153,7 +1153,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.helper.join(room_id_2, user_3, tok=user_tok_3) url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1164,7 +1164,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["total"], 3) url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1204,7 +1204,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1220,7 +1220,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"unknown_parameter": "@unknown:test"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1236,7 +1236,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": "@unknown:test"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1252,7 +1252,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": "@not:exist.bla"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1272,7 +1272,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/!unknown:test" - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1289,7 +1289,7 @@ class 
JoinAliasRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/invalidroom" - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1308,7 +1308,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1320,7 +1320,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if user is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1337,7 +1337,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/join/{}".format(private_room_id) body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1367,7 +1367,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if server admin is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1378,7 +1378,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/join/{}".format(private_room_id) body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1389,7 +1389,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if user is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1406,7 +1406,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/join/{}".format(private_room_id) body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1418,7 +1418,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if user is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 907b49f889..73f8a8ec99 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -46,7 +46,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ Try to list users without authentication. 
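The join tests above drive `/_synapse/admin/v1/join/<room_id_or_alias>`. The same call against a live server, again with the third-party `requests` library and hypothetical placeholders (note the room ID must be URL-escaped):

    import urllib.parse
    import requests

    BASE = "https://homeserver.example"  # hypothetical homeserver
    ADMIN_TOKEN = "syt_admin_token"      # hypothetical admin access token

    room_id = "!somewhere:example"
    resp = requests.post(
        BASE + "/_synapse/admin/v1/join/" + urllib.parse.quote(room_id),
        json={"user_id": "@someone:example"},
        headers={"Authorization": "Bearer " + ADMIN_TOKEN},
        timeout=10,
    )
    print(resp.status_code, resp.json())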
""" - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -55,7 +55,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error 403 is returned. """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url, json.dumps({}), access_token=self.other_user_tok, ) @@ -67,7 +67,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): If parameters are invalid, an error is returned. """ # unkown order_by - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok, ) @@ -75,7 +75,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) @@ -83,7 +83,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, ) @@ -91,7 +91,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from_ts - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok, ) @@ -99,7 +99,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative until_ts - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok, ) @@ -107,7 +107,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # until_ts smaller from_ts - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from_ts=10&until_ts=5", access_token=self.admin_user_tok, @@ -117,7 +117,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # empty search term - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?search_term=", access_token=self.admin_user_tok, ) @@ -125,7 +125,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, ) @@ -138,7 +138,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ self._create_users_with_media(10, 2) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) @@ -154,7 +154,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ self._create_users_with_media(20, 2) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5", 
access_token=self.admin_user_tok, ) @@ -170,7 +170,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ self._create_users_with_media(20, 2) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) @@ -190,7 +190,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) @@ -201,7 +201,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of max results is larger than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=21", access_token=self.admin_user_tok, ) @@ -212,7 +212,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # `next_token` does appear # Number of max results is smaller than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=19", access_token=self.admin_user_tok, ) @@ -223,7 +223,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # Set `from` to value of `next_token` for request remaining entries # Check `next_token` does not appear - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=19", access_token=self.admin_user_tok, ) @@ -238,9 +238,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): if users have no media created """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -316,15 +314,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ts1 = self.clock.time_msec() # list all media when filter is not set - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media # result is 0 - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -337,7 +333,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self._create_media(self.other_user_tok, 3) # filter media between `ts1` and `ts2` - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, @@ -346,7 +342,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media until `ts2` and earlier - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -356,14 +352,12 @@ class 
UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self._create_users_with_media(20, 1) # check without filter get all users - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) # filter user 1 and 10-19 by `user_id` - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, @@ -372,7 +366,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["total"], 11) # filter on this user in `displayname` - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, @@ -382,7 +376,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["total"], 1) # filter and get empty result - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -447,7 +441,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): url = self.url + "?order_by=%s" % (order_type,) if dir is not None and dir in ("b", "f"): url += "&dir=%s" % (dir,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ba1438cdc7..582f983225 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -70,7 +70,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ self.hs.config.registration_shared_secret = None - request, channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request("POST", self.url, b"{}") self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -87,7 +87,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs.get_secrets = Mock(return_value=secrets) - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) self.assertEqual(channel.json_body, {"nonce": "abcd"}) @@ -96,14 +96,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): Calling GET on the endpoint will return a randomised nonce, which will only last for SALT_TIMEOUT (60s). 
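The nonce lifecycle these registration tests advance the reactor through (valid for 60 seconds, single use, otherwise "unrecognised nonce") can be captured in a few lines. A sketch under those assumptions, not Synapse's actual store:

    import time

    class NonceStore:
        TIMEOUT = 60  # seconds, matching the SALT_TIMEOUT the test steps past

        def __init__(self) -> None:
            self._issued = {}  # nonce -> monotonic issue time

        def issue(self, nonce: str) -> None:
            self._issued[nonce] = time.monotonic()

        def consume(self, nonce: str) -> bool:
            # pop() makes the nonce single-use: a second consume fails,
            # as does any use more than TIMEOUT seconds after issue.
            issued_at = self._issued.pop(nonce, None)
            return issued_at is not None and time.monotonic() - issued_at <= self.TIMEOUT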
""" - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] # 59 seconds self.reactor.advance(59) body = json.dumps({"nonce": nonce}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("username must be specified", channel.json_body["error"]) @@ -111,7 +111,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # 61 seconds self.reactor.advance(2) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -120,7 +120,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ Only the provided nonce can be used, as it's checked in the MAC. """ - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -136,7 +136,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("HMAC incorrect", channel.json_body["error"]) @@ -146,7 +146,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): When the correct nonce is provided, and the right key is provided, the user is registered. """ - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -165,7 +165,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -174,7 +174,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ A valid unrecognised nonce. 
""" - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -190,13 +190,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -209,7 +209,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ def nonce(): - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) return channel.json_body["nonce"] # @@ -218,7 +218,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("nonce must be specified", channel.json_body["error"]) @@ -229,28 +229,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce()}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) @@ -261,28 +261,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("password must be specified", 
channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) @@ -300,7 +300,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "user_type": "invalid", } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid user type", channel.json_body["error"]) @@ -311,7 +311,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ # set no displayname - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -321,17 +321,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps( {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac} ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob1:test", channel.json_body["user_id"]) - request, channel = self.make_request("GET", "/profile/@bob1:test/displayname") + channel = self.make_request("GET", "/profile/@bob1:test/displayname") self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -347,17 +347,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob2:test", channel.json_body["user_id"]) - request, channel = self.make_request("GET", "/profile/@bob2:test/displayname") + channel = self.make_request("GET", "/profile/@bob2:test/displayname") self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname 
         # displayname is empty
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -373,16 +373,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob3:test", channel.json_body["user_id"])
-        request, channel = self.make_request("GET", "/profile/@bob3:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob3:test/displayname")
         self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
         # set displayname
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -398,12 +398,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob4:test", channel.json_body["user_id"])
-        request, channel = self.make_request("GET", "/profile/@bob4:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob4:test/displayname")
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Bob's Name", channel.json_body["displayname"])
@@ -429,7 +429,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         )
         # Register new user with admin API
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -448,7 +448,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -473,7 +473,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         """
         Try to list users without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
@@ -482,7 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         """
         List all users, including deactivated users.
""" - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?deactivated=true", b"{}", @@ -520,14 +520,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@bob:test" - request, channel = self.make_request( - "GET", url, access_token=self.other_user_token, - ) + channel = self.make_request("GET", url, access_token=self.other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.other_user_token, content=b"{}", ) @@ -539,7 +537,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v2/users/@unknown_person:test", access_token=self.admin_user_tok, @@ -565,7 +563,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -581,9 +579,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -612,7 +608,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -628,9 +624,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -656,9 +650,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Sync to set admin user to active # before limit of monthly active users is reached - request, channel = self.make_request( - "GET", "/sync", access_token=self.admin_user_tok - ) + channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) if channel.code != 200: raise HttpResponseException( @@ -681,7 +673,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Create user body = json.dumps({"password": "abc123", "admin": False}) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -720,7 +712,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Create user body = json.dumps({"password": "abc123", "admin": False}) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -757,7 +749,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -801,7 +793,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -827,7 +819,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Change password
         body = json.dumps({"password": "hahaha"})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -844,7 +836,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Modify user
         body = json.dumps({"displayname": "foobar"})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -856,7 +848,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("foobar", channel.json_body["displayname"])
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
@@ -874,7 +866,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -887,7 +879,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
@@ -904,7 +896,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Deactivate user
         body = json.dumps({"deactivated": True})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -917,7 +909,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # the user is deactivated, the threepid will be deleted
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
@@ -931,7 +923,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
         # Deactivate the user.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -944,7 +936,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self._is_erased("@user:test", True)
         # Attempt to reactivate the user (without a password).
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -953,7 +945,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         # Reactivate the user.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -964,7 +956,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
@@ -981,7 +973,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Set a user as an admin
         body = json.dumps({"admin": True})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -993,7 +985,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(True, channel.json_body["admin"])
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
@@ -1011,7 +1003,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Create user
         body = json.dumps({"password": "abc123"})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -1023,9 +1015,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("bob", channel.json_body["displayname"])
         # Get user
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1035,7 +1025,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Change password (and use a str for deactivate instead of a bool)
         body = json.dumps({"password": "abc123", "deactivated": "false"})  # oops!
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -1045,9 +1035,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         # Check user is not deactivated
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1089,7 +1077,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         """
         Try to list rooms of a user without authentication.
""" - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1100,9 +1088,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1112,9 +1098,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1125,9 +1109,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1138,9 +1120,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): if user has no memberships """ # Get rooms - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1157,9 +1137,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.helper.create_room_as(self.other_user, tok=other_user_tok) # Get rooms - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) @@ -1188,7 +1166,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ Try to list pushers of an user without authentication. 
""" - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1199,9 +1177,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1211,9 +1187,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1224,9 +1198,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1237,9 +1209,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ # Get pushers - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1266,9 +1236,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ) # Get pushers - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -1307,7 +1275,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ Try to list media of an user without authentication. 
""" - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1318,9 +1286,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1330,9 +1296,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/media" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1343,9 +1307,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1359,7 +1321,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") self._create_media(other_user_tok, number_media) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) @@ -1378,7 +1340,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") self._create_media(other_user_tok, number_media) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5", access_token=self.admin_user_tok, ) @@ -1397,7 +1359,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") self._create_media(other_user_tok, number_media) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) @@ -1412,7 +1374,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Testing that a negative limit parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, ) @@ -1424,7 +1386,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Testing that a negative from parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) @@ -1442,7 +1404,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) @@ -1453,7 +1415,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # `next_token` does 
         # `next_token` does not appear
         # Number of max results is larger than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
         )
@@ -1464,7 +1426,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         # `next_token` does appear
         # Number of max results is smaller than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
         )
@@ -1476,7 +1438,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         # Check
         # Set `from` to value of `next_token` to request the remaining entries
         # `next_token` does not appear
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=19", access_token=self.admin_user_tok,
         )
@@ -1491,9 +1453,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         if user has no media created
         """
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1508,9 +1468,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_media, channel.json_body["total"])
@@ -1576,7 +1534,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         )
     def _get_token(self) -> str:
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url, b"{}", access_token=self.admin_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1585,7 +1543,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
     def test_no_auth(self):
         """Try to log in as a user without authentication.
         """
-        request, channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request("POST", self.url, b"{}")
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1593,7 +1551,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
     def test_not_admin(self):
         """Try to log in as a user as a non-admin user.
""" - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, b"{}", access_token=self.other_user_tok ) @@ -1621,7 +1579,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self._get_token() # Check that we don't see a new device in our devices list - request, channel = self.make_request( + channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1636,25 +1594,19 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): puppet_token = self._get_token() # Test that we can successfully make a request - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Logout with the puppet token - request, channel = self.make_request( - "POST", "logout", b"{}", access_token=puppet_token - ) + channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # The puppet token should no longer work - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) # .. but the real user's tokens should still work - request, channel = self.make_request( + channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1667,25 +1619,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): puppet_token = self._get_token() # Test that we can successfully make a request - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Logout all with the real user token - request, channel = self.make_request( + channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # The puppet token should still work - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # .. 
         # .. but the real user's tokens shouldn't
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
@@ -1698,25 +1646,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         puppet_token = self._get_token()
         # Test that we can successfully make a request
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         # Logout all with the admin user token
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "logout/all", b"{}", access_token=self.admin_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         # The puppet token should no longer work
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         # .. but the real user's tokens should still work
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1798,11 +1742,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         Try to get information of a user without authentication.
         """
-        request, channel = self.make_request("GET", self.url1, b"{}")
+        channel = self.make_request("GET", self.url1, b"{}")
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-        request, channel = self.make_request("GET", self.url2, b"{}")
+        channel = self.make_request("GET", self.url2, b"{}")
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1813,15 +1757,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         self.register_user("user2", "pass")
         other_user2_token = self.login("user2", "pass")
-        request, channel = self.make_request(
-            "GET", self.url1, access_token=other_user2_token,
-        )
+        channel = self.make_request("GET", self.url1, access_token=other_user2_token,)
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-        request, channel = self.make_request(
-            "GET", self.url2, access_token=other_user2_token,
-        )
+        channel = self.make_request("GET", self.url2, access_token=other_user2_token,)
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1832,15 +1772,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
         url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
-        request, channel = self.make_request(
-            "GET", url1, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url1, access_token=self.admin_user_tok,)
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only whois a local user", channel.json_body["error"])
-        request, channel = self.make_request(
-            "GET", url2, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url2, access_token=self.admin_user_tok,)
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only whois a local user", channel.json_body["error"])
@@ -1848,16 +1784,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         The lookup should succeed for an admin.
         """
-        request, channel = self.make_request(
-            "GET", self.url1, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
-        request, channel = self.make_request(
-            "GET", self.url2, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
@@ -1868,16 +1800,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
-        request, channel = self.make_request(
-            "GET", self.url1, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url1, access_token=other_user_token,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
-        request, channel = self.make_request(
-            "GET", self.url2, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url2, access_token=other_user_token,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index e2e6a5e16d..c74693e9b2 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -61,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
     def test_render_public_consent(self):
         """You can observe the terms form without specifying a user"""
         resource = consent_resource.ConsentResource(self.hs)
-        request, channel = make_request(
+        channel = make_request(
             self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
         )
         self.assertEqual(channel.code, 200)
@@ -82,7 +82,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
             uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
             + "&u=user"
         )
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "GET",
@@ -97,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(consented, "False")
         # POST to the consent page, saying we've agreed
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "POST",
@@ -109,7 +109,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         # Fetch the consent page, to get the consent version -- it should have
         # changed
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "GET",
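The `test_consent.py` hunks above exercise the standalone `make_request` helper from `tests/server.py` rather than the `HomeserverTestCase` method; it changes in exactly the same way. For illustration, a minimal sketch of the new standalone form, assuming a twisted `Resource` under test (the `fetch_page` helper itself is hypothetical, not part of the patch):

    from tests.server import FakeSite, make_request


    def fetch_page(testcase, resource, path):
        # The standalone helper now also returns only the channel; FakeSite
        # supplies the Site/Request machinery around the resource under test.
        channel = make_request(
            testcase.reactor, FakeSite(resource), "GET", path, shorthand=False
        )
        testcase.assertEqual(channel.code, 200)
        return channel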
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index a1ccc4ee9a..56937dcd2e 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -93,7 +93,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
     def get_event(self, room_id, event_id, expected_code=200):
         url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
-        request, channel = self.make_request("GET", url)
+        channel = self.make_request("GET", url)
         self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 259c6a1985..c0a9fc6925 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -43,9 +43,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
-        request, channel = self.make_request(
-            b"POST", "/createRoom", b"{}", access_token=tok
-        )
+        channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
         room_id = channel.json_body["room_id"]
@@ -56,7 +54,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         }
         request_data = json.dumps(params)
         request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", request_url, request_data, access_token=tok
         )
         self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index c1f516cc93..f0707646bb 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -69,16 +69,12 @@ class RedactionsTestCase(HomeserverTestCase):
         """
         path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
-        request, channel = self.make_request(
-            "POST", path, content={}, access_token=access_token
-        )
+        channel = self.make_request("POST", path, content={}, access_token=access_token)
         self.assertEqual(int(channel.result["code"]), expect_code)
         return channel.json_body
     def _sync_room_timeline(self, access_token, room_id):
-        request, channel = self.make_request(
-            "GET", "sync", access_token=self.mod_access_token
-        )
+        channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
         self.assertEqual(channel.result["code"], b"200")
         room_sync = channel.json_body["rooms"]["join"][room_id]
         return room_sync["timeline"]["events"]
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index f56b5d9231..31dc832fd5 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -325,7 +325,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
     def get_event(self, room_id, event_id, expected_code=200):
         url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
-        request, channel = self.make_request("GET", url, access_token=self.token)
+        channel = self.make_request("GET", url, access_token=self.token)
         self.assertEqual(channel.code, expected_code, channel.result)
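For the `HomeserverTestCase`-based files, the pattern after this patch is uniform: bind the single return value and assert on it. A minimal sketch of that idiom (a hypothetical test case, not part of the patch, assuming the login servlet is registered via `servlets`):

    from synapse.rest.client.v1 import login

    from tests import unittest


    class MakeRequestIdiomTestCase(unittest.HomeserverTestCase):
        # Hypothetical example showing the updated calling convention.
        servlets = [login.register_servlets]

        def test_get_login_flows(self):
            # make_request now returns only the FakeChannel; the status code
            # and parsed JSON body are read straight off it.
            channel = self.make_request("GET", "/_matrix/client/r0/login")
            self.assertEqual(channel.code, 200, channel.result)
            self.assertIn("flows", channel.json_body)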
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 94dcfb9f7c..e689c3fbea 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -89,7 +89,7 @@ class RoomTestCase(_ShadowBannedBase):
         )
         # Inviting the user completes successfully.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/rooms/%s/invite" % (room_id,),
             {"id_server": "test", "medium": "email", "address": "test@test.test"},
@@ -103,7 +103,7 @@ class RoomTestCase(_ShadowBannedBase):
     def test_create_room(self):
         """Invitations during a room creation should be discarded, but the room still gets created."""
         # The room creation is successful.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/createRoom",
             {"visibility": "public", "invite": [self.other_user_id]},
@@ -158,7 +158,7 @@ class RoomTestCase(_ShadowBannedBase):
             self.banned_user_id, tok=self.banned_access_token
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
             {"new_version": "6"},
@@ -183,7 +183,7 @@ class RoomTestCase(_ShadowBannedBase):
             self.banned_user_id, tok=self.banned_access_token
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
             {"typing": True, "timeout": 30000},
@@ -198,7 +198,7 @@ class RoomTestCase(_ShadowBannedBase):
         # The other user can join and send typing events.
         self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (room_id, self.other_user_id),
             {"typing": True, "timeout": 30000},
@@ -244,7 +244,7 @@ class ProfileTestCase(_ShadowBannedBase):
         )
         # The update should succeed.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
             {"displayname": new_display_name},
@@ -254,7 +254,7 @@ class ProfileTestCase(_ShadowBannedBase):
         self.assertEqual(channel.json_body, {})
         # The user's display name should be updated.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/profile/%s/displayname" % (self.banned_user_id,)
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -282,7 +282,7 @@ class ProfileTestCase(_ShadowBannedBase):
         )
         # The update should succeed.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
             % (room_id, self.banned_user_id),
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 0e96697f9b..227fffab58 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -86,7 +86,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         callback = Mock(spec=[], side_effect=check)
         current_rules_module().check_event_allowed = callback
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
             {},
@@ -104,7 +104,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             self.assertEqual(ev.type, k[0])
             self.assertEqual(ev.state_key, k[1])
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
             {},
@@ -123,7 +123,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         current_rules_module().check_event_allowed = check
         # now send the event
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
             {"x": "x"},
@@ -142,7 +142,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         current_rules_module().check_event_allowed = check
         # now send the event
-        request, channel = self.make_request(
+        channel = self.make_request(
            "PUT",
             "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
             {"x": "x"},
@@ -152,7 +152,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         event_id = channel.json_body["event_id"]
         # ... and check that it got modified
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
             access_token=self.tok,
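Helper methods such as `_redact_event` and `_sync_room_timeline` above show the knock-on simplification: with only the channel returned, a helper can assert and hand back `channel.json_body` in a couple of lines. A sketch of that shape (hypothetical helper; `self.tok` stands in for whatever access token the test case holds):

    def _put_and_get_body(self, path, content):
        # Hypothetical helper following the pattern in test_redactions.py:
        # assert on the channel, then return the parsed response body.
        channel = self.make_request("PUT", path, content, access_token=self.tok)
        self.assertEqual(channel.code, 200, channel.result)
        return channel.json_body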
data = {"room_alias_name": random_string(5)} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, 200, channel.result) @@ -118,7 +118,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): data = {"aliases": [self.random_alias(alias_length)]} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, expected_code, channel.result) @@ -128,7 +128,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): data = {"room_id": self.room_id} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 12a93f5687..0a5ca317ea 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -63,13 +63,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): # implementation is now part of the r0 implementation, the newer # behaviour is used instead to be consistent with the r0 spec. # see issue #2602 - request, channel = self.make_request( + channel = self.make_request( "GET", "/events?access_token=%s" % ("invalid" + self.token,) ) self.assertEquals(channel.code, 401, msg=channel.result) # valid token, expect content - request, channel = self.make_request( + channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) self.assertEquals(channel.code, 200, msg=channel.result) @@ -87,7 +87,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): ) # valid token, expect content - request, channel = self.make_request( + channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) self.assertEquals(channel.code, 200, msg=channel.result) @@ -149,7 +149,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): resp = self.helper.send(self.room_id, tok=self.token) event_id = resp["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/events/" + event_id, access_token=self.token, ) self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 176ddf7ec9..041f2766df 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -63,7 +63,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -82,7 +82,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -108,7 +108,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request, channel = 
self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -127,7 +127,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -153,7 +153,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -172,7 +172,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"403", channel.result) @@ -181,7 +181,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.register_user("kermit", "monkey") # we shouldn't be able to make requests without an access token - request, channel = self.make_request(b"GET", TEST_URL) + channel = self.make_request(b"GET", TEST_URL) self.assertEquals(channel.result["code"], b"401", channel.result) self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") @@ -191,25 +191,21 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -223,9 +219,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -233,16 +227,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # ... 
         # ... but if we delete that device, it will be a proper logout
         self._delete_device(access_token_2, "kermit", "monkey", device_id)
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], False)
     def _delete_device(self, access_token, user_id, password, device_id):
         """Perform the UI-Auth to delete a device"""
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"DELETE", "devices/" + device_id, access_token=access_token
         )
         self.assertEquals(channel.code, 401, channel.result)
@@ -262,7 +254,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "session": channel.json_body["session"],
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"DELETE",
             "devices/" + device_id,
             access_token=access_token,
@@ -278,26 +270,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         access_token = self.login("kermit", "monkey")
         # we should now be able to make requests with the access token
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 200, channel.result)
         # time passes
         self.reactor.advance(24 * 3600)
         # ... and we should be soft-logged out
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
         # Now try to hard logout this session
-        request, channel = self.make_request(
-            b"POST", "/logout", access_token=access_token
-        )
+        channel = self.make_request(b"POST", "/logout", access_token=access_token)
         self.assertEquals(channel.result["code"], b"200", channel.result)
     @override_config({"session_lifetime": "24h"})
@@ -308,26 +294,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         access_token = self.login("kermit", "monkey")
         # we should now be able to make requests with the access token
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 200, channel.result)
         # time passes
         self.reactor.advance(24 * 3600)
         # ... and we should be soft-logged out
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
         # Now try to hard log out all of the user's sessions
-        request, channel = self.make_request(
-            b"POST", "/logout/all", access_token=access_token
-        )
+        channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
         self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -402,7 +382,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         cas_ticket_url = urllib.parse.urlunparse(url_parts)
         # Get Synapse to call the fake CAS and serve the template.
-        request, channel = self.make_request("GET", cas_ticket_url)
+        channel = self.make_request("GET", cas_ticket_url)
         # Test that the response is HTML.
         self.assertEqual(channel.code, 200)
@@ -446,7 +426,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         )
         # Get Synapse to call the fake CAS and serve the template.
-        request, channel = self.make_request("GET", cas_ticket_url)
+        channel = self.make_request("GET", cas_ticket_url)
         self.assertEqual(channel.code, 302)
         location_headers = channel.headers.getRawHeaders("Location")
@@ -472,7 +452,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         )
         # Get Synapse to call the fake CAS and serve the template.
-        request, channel = self.make_request("GET", cas_ticket_url)
+        channel = self.make_request("GET", cas_ticket_url)
         # Because the user is deactivated they are served an error template.
         self.assertEqual(channel.code, 403)
@@ -502,7 +482,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         params = json.dumps(
             {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         )
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
     def test_login_jwt_valid_registered(self):
@@ -634,7 +614,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_no_token(self):
         params = json.dumps({"type": "org.matrix.login.jwt"})
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -707,7 +687,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         params = json.dumps(
             {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         )
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
     def test_login_jwt_valid(self):
@@ -735,7 +715,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
     ]
     def register_as_user(self, username):
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST",
             "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
             {"username": username},
@@ -784,7 +764,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": AS_USER},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
@@ -799,7 +779,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": self.service.sender},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
@@ -814,7 +794,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token
        )
@@ -829,7 +809,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": AS_USER},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.another_service.token
         )
@@ -845,6 +825,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": AS_USER},
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 11cd8efe21..94a5154834 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -53,7 +53,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
         self.hs.config.use_presence = True
         body = {"presence": "here", "status_msg": "beep boop"}
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/presence/%s/status" % (self.user_id,), body
         )
@@ -68,7 +68,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
         self.hs.config.use_presence = False
         body = {"presence": "here", "status_msg": "beep boop"}
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/presence/%s/status" % (self.user_id,), body
         )
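Almost every hunk in this patch is the same textual rewrite, which makes the sweep largely mechanical. A rough, illustrative sketch of the rewrite rule (not the actual tooling used for this change; re-wrapping the now-shorter calls still needs a formatter):

    import re

    # Illustrative only: drop the unused `request` binding in front of
    # make_request assignments, e.g.
    #   "request, channel = self.make_request(" -> "channel = self.make_request("
    _PATTERN = re.compile(r"request,\s*channel(\s*=\s*(?:self\.)?make_request\()")


    def rewrite(source: str) -> str:
        return _PATTERN.sub(r"channel\1", source)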
""" - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/" + self.requester, access_token=self.requester_tok ) self.assertEqual(channel.code, 200, channel.result) - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/" + self.requester + "/displayname", access_token=self.requester_tok, ) self.assertEqual(channel.code, 200, channel.result) - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/" + self.requester + "/avatar_url", access_token=self.requester_tok, diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py index 7add5523c8..2bc512d75e 100644 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ b/tests/rest/client/v1/test_push_rule_attrs.py @@ -45,13 +45,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # GET enabled for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) @@ -74,13 +74,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # disable the rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": False}, @@ -89,26 +89,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) # check rule disabled - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], False) # DELETE the rule - request, channel = self.make_request( + channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) self.assertEqual(channel.code, 200) # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # GET enabled for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) @@ -130,13 +130,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # disable the rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": False}, @@ -145,14 +145,14 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) # check rule disabled - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], False) # 
re-enable the rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": True}, @@ -161,7 +161,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) # check rule enabled - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) @@ -182,32 +182,32 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # GET enabled for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) # DELETE the rule - request, channel = self.make_request( + channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) self.assertEqual(channel.code, 200) # check 404 for deleted rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 404) @@ -221,7 +221,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token ) self.assertEqual(channel.code, 404) @@ -235,7 +235,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": True}, @@ -252,7 +252,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/.m.muahahah/enabled", {"enabled": True}, @@ -276,13 +276,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # GET actions for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/actions", access_token=token ) self.assertEqual(channel.code, 200) @@ -305,13 +305,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # change the rule actions - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/actions", {"actions": 
["dont_notify"]}, @@ -320,7 +320,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) # GET actions for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/actions", access_token=token ) self.assertEqual(channel.code, 200) @@ -341,26 +341,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # DELETE the rule - request, channel = self.make_request( + channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) self.assertEqual(channel.code, 200) # check 404 for deleted rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 404) @@ -374,7 +374,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token ) self.assertEqual(channel.code, 404) @@ -388,7 +388,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/actions", {"actions": ["dont_notify"]}, @@ -405,7 +405,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/.m.muahahah/actions", {"actions": ["dont_notify"]}, diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 55d872f0ee..6105eac47c 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -84,13 +84,13 @@ class RoomPermissionsTestCase(RoomBase): self.created_rmid_msg_path = ( "rooms/%s/send/m.room.message/a1" % (self.created_rmid) ).encode("ascii") - request, channel = self.make_request( + channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) self.assertEquals(200, channel.code, channel.result) # set topic for public room - request, channel = self.make_request( + channel = self.make_request( "PUT", ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', @@ -112,7 +112,7 @@ class RoomPermissionsTestCase(RoomBase): ) # send message in uncreated room, expect 403 - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, @@ -120,24 +120,24 @@ class RoomPermissionsTestCase(RoomBase): self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 - request, channel = 
self.make_request("PUT", send_msg_path(), msg_content) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = self.make_request("PUT", send_msg_path(), msg_content) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) - request, channel = self.make_request("PUT", send_msg_path(), msg_content) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = self.make_request("PUT", send_msg_path(), msg_content) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): @@ -145,30 +145,30 @@ class RoomPermissionsTestCase(RoomBase): topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid # set/get topic in uncreated room, expect 403 - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) self.assertEquals(403, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 - request, channel = self.make_request("PUT", topic_path, topic_content) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", topic_path) + channel = self.make_request("GET", topic_path) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = self.make_request("PUT", topic_path, topic_content) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 - request, channel = self.make_request("GET", topic_path) + channel = self.make_request("GET", topic_path) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 @@ -176,29 +176,29 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id - request, channel = self.make_request("PUT", topic_path, topic_content) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id - request, channel = self.make_request("GET", topic_path) + channel = self.make_request("GET", topic_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room 
and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = self.make_request("PUT", topic_path, topic_content) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", topic_path) + channel = self.make_request("GET", topic_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, @@ -208,7 +208,7 @@ class RoomPermissionsTestCase(RoomBase): def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) - request, channel = self.make_request("GET", path) + channel = self.make_request("GET", path) self.assertEquals(expect_code, channel.code) def test_membership_basic_room_perms(self): @@ -380,16 +380,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEquals(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): - request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") + channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): room_id = self.helper.create_room_as("@some_other_guy:red") - request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): @@ -398,17 +398,17 @@ class RoomsMemberListTestCase(RoomBase): room_path = "/rooms/%s/members" % room_id self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. 
- request, channel = self.make_request("GET", room_path) + channel = self.make_request("GET", room_path) self.assertEquals(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined - request, channel = self.make_request("GET", room_path) + channel = self.make_request("GET", room_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left - request, channel = self.make_request("GET", room_path) + channel = self.make_request("GET", room_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -419,30 +419,26 @@ class RoomsCreateTestCase(RoomBase): def test_post_room_no_keys(self): # POST with no config keys, expect new room id - request, channel = self.make_request("POST", "/createRoom", "{}") + channel = self.make_request("POST", "/createRoom", "{}") self.assertEquals(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id - request, channel = self.make_request( - "POST", "/createRoom", b'{"visibility":"private"}' - ) + channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id - request, channel = self.make_request( - "POST", "/createRoom", b'{"custom":"stuff"}' - ) + channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self): # POST with custom + known config keys, expect new room id - request, channel = self.make_request( + channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) self.assertEquals(200, channel.code) @@ -450,16 +446,16 @@ class RoomsCreateTestCase(RoomBase): def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 - request, channel = self.make_request("POST", "/createRoom", b'{"visibili') + channel = self.make_request("POST", "/createRoom", b'{"visibili') self.assertEquals(400, channel.code) - request, channel = self.make_request("POST", "/createRoom", b'["hello"]') + channel = self.make_request("POST", "/createRoom", b'["hello"]') self.assertEquals(400, channel.code) def test_post_room_invitees_invalid_mxid(self): # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 # Note the trailing space in the MXID here! 
- request, channel = self.make_request( + channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) self.assertEquals(400, channel.code) @@ -477,54 +473,54 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self): # missing keys or invalid json - request, channel = self.make_request("PUT", self.path, "{}") + channel = self.make_request("PUT", self.path, "{}") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}') + channel = self.make_request("PUT", self.path, '{"_name":"bo"}') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, '{"nao') + channel = self.make_request("PUT", self.path, '{"nao') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, "text only") + channel = self.make_request("PUT", self.path, "text only") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, "") + channel = self.make_request("PUT", self.path, "") self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' - request, channel = self.make_request("PUT", self.path, content) + channel = self.make_request("PUT", self.path, content) self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there - request, channel = self.make_request("GET", self.path) + channel = self.make_request("GET", self.path) self.assertEquals(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' - request, channel = self.make_request("PUT", self.path, content) + channel = self.make_request("PUT", self.path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = self.make_request("GET", self.path) + channel = self.make_request("GET", self.path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - request, channel = self.make_request("PUT", self.path, content) + channel = self.make_request("PUT", self.path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = self.make_request("GET", self.path) + channel = self.make_request("GET", self.path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -540,24 +536,22 @@ class RoomMemberStateTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - request, channel = self.make_request("PUT", path, "{}") + channel = self.make_request("PUT", path, "{}") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, '{"_name":"bo"}') + channel = self.make_request("PUT", path, '{"_name":"bo"}') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, '{"nao') + 
channel = self.make_request("PUT", path, '{"nao') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( - "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' - ) + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, "text only") + channel = self.make_request("PUT", path, "text only") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, "") + channel = self.make_request("PUT", path, "") self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types @@ -566,7 +560,7 @@ class RoomMemberStateTestCase(RoomBase): Membership.JOIN, Membership.LEAVE, ) - request, channel = self.make_request("PUT", path, content.encode("ascii")) + channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): @@ -577,10 +571,10 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - request, channel = self.make_request("PUT", path, content.encode("ascii")) + channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, None) self.assertEquals(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} @@ -595,10 +589,10 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, None) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -614,10 +608,10 @@ class RoomMemberStateTestCase(RoomBase): Membership.INVITE, "Join us!", ) - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, None) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -668,7 +662,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id - request, channel = self.make_request("PUT", path, {"displayname": "John Doe"}) + channel = self.make_request("PUT", path, {"displayname": "John Doe"}) self.assertEquals(channel.code, 200, channel.json_body) # Check that all the rooms have been sent a profile update into. 
@@ -678,7 +672,7 @@ class RoomJoinRatelimitTestCase(RoomBase): self.user_id, ) - request, channel = self.make_request("GET", path) + channel = self.make_request("GET", path) self.assertEquals(channel.code, 200) self.assertIn("displayname", channel.json_body) @@ -702,7 +696,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Make sure we send more requests than the rate-limiting config would allow # if all of these requests ended up joining the user to a room. for i in range(4): - request, channel = self.make_request("POST", path % room_id, {}) + channel = self.make_request("POST", path % room_id, {}) self.assertEquals(channel.code, 200) @unittest.override_config( @@ -731,42 +725,40 @@ class RoomMessagesTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - request, channel = self.make_request("PUT", path, b"{}") + channel = self.make_request("PUT", path, b"{}") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'{"_name":"bo"}') + channel = self.make_request("PUT", path, b'{"_name":"bo"}') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'{"nao') + channel = self.make_request("PUT", path, b'{"nao') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( - "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' - ) + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b"text only") + channel = self.make_request("PUT", path, b"text only") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b"") + channel = self.make_request("PUT", path, b"") self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(400, channel.code, msg=channel.result["body"]) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -780,9 +772,7 @@ class RoomInitialSyncTestCase(RoomBase): self.room_id = self.helper.create_room_as(self.user_id) def test_initial_sync(self): - request, channel = self.make_request( - "GET", "/rooms/%s/initialSync" % self.room_id - ) + channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) self.assertEquals(200, channel.code) self.assertEquals(self.room_id, channel.json_body["room_id"]) @@ -823,7 +813,7 @@ class RoomMessageListTestCase(RoomBase): def test_topo_token_is_accepted(self): token = "t1-0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = 
self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) self.assertEquals(200, channel.code) @@ -834,7 +824,7 @@ class RoomMessageListTestCase(RoomBase): def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) self.assertEquals(200, channel.code) @@ -867,7 +857,7 @@ class RoomMessageListTestCase(RoomBase): self.helper.send(self.room_id, "message 3") # Check that we get the first and second message when querying /messages. - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" % ( @@ -895,7 +885,7 @@ class RoomMessageListTestCase(RoomBase): # Check that we only get the second message through /message now that the first # has been purged. - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" % ( @@ -912,7 +902,7 @@ class RoomMessageListTestCase(RoomBase): # Check that we get no event, but also no error, when querying /messages with # the token that was pointing at the first event, because we don't have it # anymore. - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" % ( @@ -971,7 +961,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.helper.send(self.room, body="Hi!", tok=self.other_access_token) self.helper.send(self.room, body="There!", tok=self.other_access_token) - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % (self.access_token,), { @@ -1000,7 +990,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.helper.send(self.room, body="Hi!", tok=self.other_access_token) self.helper.send(self.room, body="There!", tok=self.other_access_token) - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % (self.access_token,), { @@ -1048,14 +1038,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): return self.hs def test_restricted_no_auth(self): - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401, channel.result) def test_restricted_auth(self): self.register_user("user", "pass") tok = self.login("user", "pass") - request, channel = self.make_request("GET", self.url, access_token=tok) + channel = self.make_request("GET", self.url, access_token=tok) self.assertEqual(channel.code, 200, channel.result) @@ -1083,7 +1073,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.displayname = "test user" data = {"displayname": self.displayname} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), request_data, @@ -1096,7 +1086,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): def test_per_room_profile_forbidden(self): data = {"membership": "join", "displayname": "other test user"} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id), @@ -1106,7 +1096,7 @@ class 
PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, @@ -1139,7 +1129,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): def test_join_reason(self): reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/join".format(self.room_id), content={"reason": reason}, @@ -1153,7 +1143,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), content={"reason": reason}, @@ -1167,7 +1157,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, @@ -1181,7 +1171,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, @@ -1193,7 +1183,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): def test_unban_reason(self): reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, @@ -1205,7 +1195,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): def test_invite_reason(self): reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, @@ -1224,7 +1214,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), content={"reason": reason}, @@ -1235,7 +1225,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) def _check_for_reason(self, reason): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( self.room_id, self.second_user_id @@ -1284,7 +1274,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by a label on a /context request.""" event_id = self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), @@ -1314,7 +1304,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by the absence of a label on a /context request.""" event_id = self._send_labelled_messages_in_room() - request, 
channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), @@ -1349,7 +1339,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): """ event_id = self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), @@ -1377,7 +1367,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), @@ -1394,7 +1384,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), @@ -1417,7 +1407,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % ( @@ -1448,7 +1438,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) @@ -1483,7 +1473,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) @@ -1530,7 +1520,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) @@ -1651,7 +1641,7 @@ class ContextTestCase(unittest.HomeserverTestCase): # Check that we can still see the messages before the erasure request. - request, channel = self.make_request( + channel = self.make_request( "GET", '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' % (self.room_id, event_id), @@ -1715,7 +1705,7 @@ class ContextTestCase(unittest.HomeserverTestCase): # Check that a user that joined the room after the erasure request can't see # the messages anymore. - request, channel = self.make_request( + channel = self.make_request( "GET", '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' % (self.room_id, event_id), @@ -1805,7 +1795,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. 
returns the json response object.""" - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases" % (self.room_id,), @@ -1826,7 +1816,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): data = {"room_id": self.room_id} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok ) self.assertEqual(channel.code, expected_code, channel.result) @@ -1856,14 +1846,14 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): data = {"room_id": self.room_id} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok ) self.assertEqual(channel.code, expected_code, channel.result) def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" - request, channel = self.make_request( + channel = self.make_request( "GET", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), access_token=self.room_owner_tok, @@ -1875,7 +1865,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" - request, channel = self.make_request( + channel = self.make_request( "PUT", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), json.dumps(content), diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index ae0207366b..38c51525a3 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -94,7 +94,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user="@jim:red") def test_set_typing(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', @@ -117,7 +117,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): ) def test_set_not_typing(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": false}', @@ -125,7 +125,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) def test_typing_timeout(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', @@ -138,7 +138,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 2) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 5a18af8d34..3d46117e65 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -81,7 +81,7 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "POST", @@ -157,7 +157,7 @@ class RestHelper: data = {"membership": membership} data.update(extra_data) - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "PUT", 
@@ -192,7 +192,7 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "PUT", @@ -248,9 +248,7 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - _, channel = make_request( - self.hs.get_reactor(), self.site, method, path, content - ) + channel = make_request(self.hs.get_reactor(), self.site, method, path, content) assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" @@ -333,7 +331,7 @@ class RestHelper: """ image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), FakeSite(resource), "POST", @@ -366,7 +364,7 @@ class RestHelper: client_redirect_url = "https://x" # first hit the redirect url (which will issue a cookie and state) - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "GET", @@ -411,7 +409,7 @@ class RestHelper: with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): # now hit the callback URI with the right params and a made-up code - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "GET", @@ -434,7 +432,7 @@ class RestHelper: # finally, submit the matrix login token to the login API, which gives us our # matrix access token and device id. - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "POST", diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 8a898be24c..cb87b80e33 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -240,7 +240,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) def _request_token(self, email, client_secret): - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/password/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, @@ -254,7 +254,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") # Load the password reset confirmation page - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.submit_token_resource), "GET", @@ -268,7 +268,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # password reset confirm button # Confirm the password reset - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.submit_token_resource), "POST", @@ -302,7 +302,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): def _reset_password( self, new_password, session_id, client_secret, expected_code=200 ): - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/password", { @@ -344,7 +344,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) # Check that this access token has been invalidated. 
- request, channel = self.make_request("GET", "account/whoami") + channel = self.make_request("GET", "account/whoami") self.assertEqual(channel.code, 401) def test_pending_invites(self): @@ -399,7 +399,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): "erase": False, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) self.assertEqual(channel.code, 200) @@ -522,7 +522,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self._validate_token(link) - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -540,7 +540,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -561,7 +561,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/3pid/delete", {"medium": "email", "address": self.email}, @@ -570,7 +570,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -593,7 +593,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/3pid/delete", {"medium": "email", "address": self.email}, @@ -604,7 +604,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -621,7 +621,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEquals(len(self.email_attempts), 1) # Attempt to add email without clicking the link - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -639,7 +639,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -654,7 +654,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): session_id = "weasle" # Attempt to add email without even requesting an email - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -672,7 +672,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -776,9 +776,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): if next_link: body["next_link"] = next_link - request, channel = self.make_request( - "POST", b"account/3pid/email/requestToken", body, - ) + channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) self.assertEquals(expect_code, channel.code, 
channel.result) return channel.json_body.get("sid") @@ -786,7 +784,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _request_token_invalid_email( self, email, expected_errcode, expected_error, client_secret="foobar", ): - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, @@ -799,7 +797,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Remove the host path = link.replace("https://example.com", "") - request, channel = self.make_request("GET", path, shorthand=False) + channel = self.make_request("GET", path, shorthand=False) self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): @@ -833,7 +831,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self._validate_token(link) - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -851,7 +849,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index a35c2646fb..da866c8b44 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -67,9 +67,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): def register(self, expected_response: int, body: JsonDict) -> FakeChannel: """Make a register request.""" - request, channel = self.make_request( - "POST", "register", body - ) # type: SynapseRequest, FakeChannel + channel = self.make_request("POST", "register", body) self.assertEqual(channel.code, expected_response) return channel @@ -81,12 +79,12 @@ class FallbackAuthTests(unittest.HomeserverTestCase): if post_session is None: post_session = session - request, channel = self.make_request( + channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) # type: SynapseRequest, FakeChannel + ) self.assertEqual(channel.code, 200) - request, channel = self.make_request( + channel = self.make_request( "POST", "auth/m.login.recaptcha/fallback/web?session=" + post_session @@ -184,7 +182,7 @@ class UIAuthTests(unittest.HomeserverTestCase): def get_device_ids(self, access_token: str) -> List[str]: # Get the list of devices so one can be deleted. - _, channel = self.make_request("GET", "devices", access_token=access_token,) + channel = self.make_request("GET", "devices", access_token=access_token,) self.assertEqual(channel.code, 200) return [d["device_id"] for d in channel.json_body["devices"]] @@ -196,9 +194,9 @@ class UIAuthTests(unittest.HomeserverTestCase): body: Union[bytes, JsonDict] = b"", ) -> FakeChannel: """Delete an individual device.""" - request, channel = self.make_request( + channel = self.make_request( "DELETE", "devices/" + device, body, access_token=access_token, - ) # type: SynapseRequest, FakeChannel + ) # Ensure the response is sane. self.assertEqual(channel.code, expected_response) @@ -209,9 +207,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """Delete 1 or more devices.""" # Note that this uses the delete_devices endpoint so that we can modify # the payload half-way through some tests. 
- request, channel = self.make_request( + channel = self.make_request( "POST", "delete_devices", body, access_token=self.user_tok, - ) # type: SynapseRequest, FakeChannel + ) # Ensure the response is sane. self.assertEqual(channel.code, expected_response) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index 767e126875..e808339fb3 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -36,7 +36,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): return hs def test_check_auth_required(self): - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) @@ -44,7 +44,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.register_user("user", "pass") access_token = self.login("user", "pass") - request, channel = self.make_request("GET", self.url, access_token=access_token) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) @@ -62,7 +62,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): user = self.register_user(localpart, password) access_token = self.login(user, password) - request, channel = self.make_request("GET", self.url, access_token=access_token) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) @@ -70,7 +70,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): # Test case where password is handled outside of Synapse self.assertTrue(capabilities["m.change_password"]["enabled"]) self.get_success(self.store.user_set_password_hash(user, None)) - request, channel = self.make_request("GET", self.url, access_token=access_token) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 231d5aefea..f761c44936 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -36,7 +36,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastore() def test_add_filter(self): - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, @@ -49,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEquals(filter.result, self.EXAMPLE_FILTER) def test_add_filter_for_other_user(self): - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), self.EXAMPLE_FILTER_JSON, @@ -61,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, @@ -79,7 +79,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.reactor.advance(1) filter_id = filter_id.result - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) @@ -87,7 +87,7 @@ class 
FilterTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) @@ -97,7 +97,7 @@ class FilterTestCase(unittest.HomeserverTestCase): # Currently invalid params do not have an appropriate errcode # in errors.py def test_get_filter_invalid_id(self): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) @@ -105,7 +105,7 @@ class FilterTestCase(unittest.HomeserverTestCase): # No ID also returns an invalid_id error def test_get_filter_no_id(self): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index ee86b94917..fba34def30 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -70,9 +70,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_get_policy(self): """Tests if the /password_policy endpoint returns the configured policy.""" - request, channel = self.make_request( - "GET", "/_matrix/client/r0/password_policy" - ) + channel = self.make_request("GET", "/_matrix/client/r0/password_policy") self.assertEqual(channel.code, 200, channel.result) self.assertEqual( @@ -89,7 +87,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_too_short(self): request_data = json.dumps({"username": "kermit", "password": "shorty"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -98,7 +96,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_digit(self): request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -107,7 +105,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_symbol(self): request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -116,7 +114,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_uppercase(self): request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -125,7 +123,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_lowercase(self): request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", 
self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -134,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_compliant(self): request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has # responded with a list of registration flows. @@ -160,7 +158,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): }, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/account/password", request_data, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 4bf3e0d63b..27db4f551e 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -61,7 +61,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.get_datastore().services_cache.append(appservice) request_data = json.dumps({"username": "as_user_kermit"}) - request, channel = self.make_request( + channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) @@ -72,7 +72,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_appservice_registration_invalid(self): self.appservice = None # no application service exists request_data = json.dumps({"username": "kermit"}) - request, channel = self.make_request( + channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) @@ -80,14 +80,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_bad_password(self): request_data = json.dumps({"username": "kermit", "password": 666}) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self): request_data = json.dumps({"username": 777, "password": "monkey"}) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid username") @@ -102,7 +102,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "auth": {"type": LoginType.DUMMY}, } request_data = json.dumps(params) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) det_data = { "user_id": user_id, @@ -117,7 +117,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Registration has been disabled") @@ -127,7 +127,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.macaroon_secret_key = "test" 
self.hs.config.allow_guest_access = True - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEquals(channel.result["code"], b"200", channel.result) @@ -136,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_disabled_guest_registration(self): self.hs.config.allow_guest_access = False - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") @@ -145,7 +145,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_ratelimiting_guest(self): for i in range(0, 6): url = self.url + b"?kind=guest" - request, channel = self.make_request(b"POST", url, b"{}") + channel = self.make_request(b"POST", url, b"{}") if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -155,7 +155,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.assertEquals(channel.result["code"], b"200", channel.result) @@ -169,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "auth": {"type": LoginType.DUMMY}, } request_data = json.dumps(params) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -179,12 +179,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.assertEquals(channel.result["code"], b"200", channel.result) def test_advertised_flows(self): - request, channel = self.make_request(b"POST", self.url, b"{}") + channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -207,7 +207,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } ) def test_advertised_flows_captcha_and_terms_and_3pids(self): - request, channel = self.make_request(b"POST", self.url, b"{}") + channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -239,7 +239,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } ) def test_advertised_flows_no_msisdn_email_required(self): - request, channel = self.make_request(b"POST", self.url, b"{}") + channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -279,7 +279,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( + channel = self.make_request( "POST", b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, @@ -318,13 +318,13 @@ class 
AccountValidityTestCase(unittest.HomeserverTestCase): # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( @@ -346,14 +346,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/account_validity/validity" params = {"user_id": user_id} request_data = json.dumps(params) - request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok - ) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) def test_manual_expire(self): @@ -370,14 +368,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): "enable_renewal_emails": False, } request_data = json.dumps(params) - request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok - ) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result @@ -397,20 +393,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): "enable_renewal_emails": False, } request_data = json.dumps(params) - request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok - ) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEquals(channel.result["code"], b"200", channel.result) # Try to log the user out - request, channel = self.make_request(b"POST", "/logout", access_token=tok) + channel = self.make_request(b"POST", "/logout", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions - request, channel = self.make_request(b"POST", "/logout/all", access_token=tok) + channel = self.make_request(b"POST", "/logout/all", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -484,7 +478,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # retrieve the token from the DB. 
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token - request, channel = self.make_request(b"GET", url) + channel = self.make_request(b"GET", url) self.assertEquals(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. @@ -504,14 +498,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # our access token should be denied from now, otherwise they should # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) def test_renewal_invalid_token(self): # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" - request, channel = self.make_request(b"GET", url) + channel = self.make_request(b"GET", url) self.assertEquals(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. @@ -532,7 +526,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts = [] (user_id, tok) = self.create_user() - request, channel = self.make_request( + channel = self.make_request( b"POST", "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, @@ -556,7 +550,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "erase": False, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) self.assertEqual(channel.code, 200) @@ -607,7 +601,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts = [] # Test that we're still able to manually trigger a mail to be sent. 
- request, channel = self.make_request( + channel = self.make_request( b"POST", "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 6cd4eb6624..bd574077e7 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -60,7 +60,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): event_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, event_id), access_token=self.user_token, @@ -107,7 +107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" % (self.room, self.parent_id), @@ -152,7 +152,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): if prev_token: from_token = "&from=" + prev_token - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" % (self.room, self.parent_id, from_token), @@ -210,7 +210,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): if prev_token: from_token = "&from=" + prev_token - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" % (self.room, self.parent_id, from_token), @@ -279,7 +279,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): if prev_token: from_token = "&from=" + prev_token - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s" "/aggregations/%s/%s/m.reaction/%s?limit=1%s" @@ -325,7 +325,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s" % (self.room, self.parent_id), @@ -357,7 +357,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Now lets redact one of the 'a' reactions - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), access_token=self.user_token, @@ -365,7 +365,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(200, channel.code, channel.json_body) - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s" % (self.room, self.parent_id), @@ -382,7 +382,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Test that aggregations must be annotations. 
""" - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" % (self.room, self.parent_id, RelationTypes.REPLACE), @@ -414,7 +414,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, @@ -450,7 +450,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, @@ -507,7 +507,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(200, channel.code, channel.json_body) - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, @@ -549,7 +549,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Check the relation is returned - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" % (self.room, original_event_id), @@ -561,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(len(channel.json_body["chunk"]), 1) # Redact the original event - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/redact/%s/%s" % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), @@ -571,7 +571,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Try to check for remaining m.replace relations - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" % (self.room, original_event_id), @@ -598,7 +598,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Redact the original - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/redact/%s/%s" % ( @@ -612,7 +612,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Check that aggregations returns zero - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" % (self.room, original_event_id), @@ -656,7 +656,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): original_id = parent_id if parent_id else self.parent_id - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, original_id, relation_type, event_type, query), diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py index 05c5ee5a78..116ace1812 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -42,13 +42,12 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.handler = hs.get_user_directory_handler() def _get_shared_rooms(self, token, other_user) -> FakeChannel: - _, channel = 
self.make_request( + return self.make_request( "GET", "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" % other_user, access_token=token, ) - return channel def test_shared_room_list_public(self): """ diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 31ac0fccb8..512e36c236 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -35,7 +35,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ] def test_sync_argless(self): - request, channel = self.make_request("GET", "/sync") + channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) self.assertTrue( @@ -55,7 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase): """ self.hs.config.use_presence = False - request, channel = self.make_request("GET", "/sync") + channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) self.assertTrue( @@ -194,7 +194,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): tok=tok, ) - request, channel = self.make_request( + channel = self.make_request( "GET", "/sync?filter=%s" % sync_filter, access_token=tok ) self.assertEqual(channel.code, 200, channel.result) @@ -245,21 +245,19 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.helper.send(room, body="There!", tok=other_access_token) # Start typing. - request, channel = self.make_request( + channel = self.make_request( "PUT", typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) self.assertEquals(200, channel.code) - request, channel = self.make_request( - "GET", "/sync?access_token=%s" % (access_token,) - ) + channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] # Stop typing. - request, channel = self.make_request( + channel = self.make_request( "PUT", typing_url % (room, other_user_id, other_access_token), b'{"typing": false}', @@ -267,7 +265,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) # Start typing. - request, channel = self.make_request( + channel = self.make_request( "PUT", typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', @@ -275,9 +273,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) # Should return immediately - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) + channel = self.make_request("GET", sync_url % (access_token, next_batch)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -289,9 +285,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): # invalidate the stream token. self.helper.send(room, body="There!", tok=other_access_token) - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) + channel = self.make_request("GET", sync_url % (access_token, next_batch)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -299,9 +293,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): # ahead, and therefore it's saying the typing (that we've actually # already seen) is new, since it's got a token above our new, now-reset # stream token. 
- request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) + channel = self.make_request("GET", sync_url % (access_token, next_batch)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -383,7 +375,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - request, channel = self.make_request( + channel = self.make_request( "POST", "/rooms/%s/read_markers" % self.room_id, body, @@ -450,7 +442,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): def _check_unread_count(self, expected_count: True): """Syncs and compares the unread count with the expected value.""" - request, channel = self.make_request( + channel = self.make_request( "GET", self.url % self.next_batch, access_token=self.tok, ) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 6f0677d335..ae2b32b131 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -228,7 +228,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _req(self, content_disposition): - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.download_resource), "GET", @@ -324,7 +324,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.thumbnail_resource), "GET", diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 529b6bcded..83d728b4a4 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -113,7 +113,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): def test_cache_returns_correct_type(self): self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False, @@ -138,7 +138,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) # Check the cache returns the correct response - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False ) @@ -154,7 +154,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertNotIn("http://matrix.org", self.preview_url._cache) # Check the database cache returns the correct response - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False ) @@ -175,7 +175,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"" ) - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False, @@ -210,7 +210,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"" ) - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False, @@ -245,7 +245,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"" ) - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False, @@ -278,7 +278,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] - request, 
channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False, @@ -308,7 +308,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) @@ -329,7 +329,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) @@ -346,7 +346,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Blacklisted IP addresses, accessed directly, are not spidered. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://192.168.1.1", shorthand=False ) @@ -365,7 +365,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Blacklisted IP ranges, accessed directly, are not spidered. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://1.1.1.2", shorthand=False ) @@ -385,7 +385,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False, @@ -422,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): (IPv4Address, "10.1.2.3"), ] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) self.assertEqual(channel.code, 502) @@ -442,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") ] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) @@ -463,7 +463,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) @@ -480,7 +480,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ OPTIONS returns the OPTIONS. 
""" - request, channel = self.make_request( + channel = self.make_request( "OPTIONS", "preview_url?url=http://example.com", shorthand=False ) self.assertEqual(channel.code, 200) @@ -493,7 +493,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] # Build and make a request to the server - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False, @@ -567,7 +567,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"" ) - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, @@ -632,7 +632,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): } end_content = json.dumps(result).encode("utf-8") - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py index aaf2fb821b..32acd93dc1 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py @@ -25,7 +25,7 @@ class HealthCheckTests(unittest.HomeserverTestCase): return HealthResource() def test_health(self): - request, channel = self.make_request("GET", "/health", shorthand=False) + channel = self.make_request("GET", "/health", shorthand=False) self.assertEqual(channel.code, 200) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 17ded96b9c..14de0921be 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -28,7 +28,7 @@ class WellKnownTests(unittest.HomeserverTestCase): self.hs.config.public_baseurl = "https://tesths" self.hs.config.default_identity_server = "https://testis" - request, channel = self.make_request( + channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) @@ -44,7 +44,7 @@ class WellKnownTests(unittest.HomeserverTestCase): def test_well_known_no_public_baseurl(self): self.hs.config.public_baseurl = None - request, channel = self.make_request( + channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) diff --git a/tests/server.py b/tests/server.py index 4faf32e335..7d1ad362c4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -174,11 +174,11 @@ def make_request( custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, -): +) -> FakeChannel: """ Make a web request using the given method, path and content, and render it - Returns the Request and the Channel underneath. + Returns the fake Channel object which records the response to the request. Args: site: The twisted Site to use to render the request @@ -202,7 +202,7 @@ def make_request( is finished. Returns: - Tuple[synapse.http.site.SynapseRequest, channel] + channel """ if not isinstance(method, bytes): method = method.encode("ascii") @@ -265,7 +265,7 @@ def make_request( if await_result: channel.await_result() - return req, channel + return channel @implementer(IReactorPluggableNameResolver) diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index e0a9cd93ac..4dd5a36178 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -70,7 +70,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): the notice URL + an authentication code. 
""" # Initial sync, to get the user consent room invite - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=self.access_token ) self.assertEqual(channel.code, 200) @@ -79,7 +79,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): room_id = list(channel.json_body["rooms"]["invite"].keys())[0] # Join the room - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/" + room_id + "/join", access_token=self.access_token, @@ -87,7 +87,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) # Sync again, to get the message in the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=self.access_token ) self.assertEqual(channel.code, 200) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 9c8027a5b2..fea54464af 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -305,7 +305,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.register_user("user", "password") tok = self.login("user", "password") - request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) invites = channel.json_body["rooms"]["invite"] self.assertEqual(len(invites), 0, invites) @@ -318,7 +318,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): # Sync again to retrieve the events in the room, so we can check whether this # room has a notice in it. - request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) # Scan the events in the room to search for a message from the server notices # user. @@ -353,9 +353,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): tok = self.login(localpart, "password") # Sync with the user's token to mark the user as active. - request, channel = self.make_request( - "GET", "/sync?timeout=0", access_token=tok, - ) + channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,) # Also retrieves the list of invites for this user. 
We don't care about that # one except if we're processing the last user, which should have received an diff --git a/tests/test_mau.py b/tests/test_mau.py index c5ec6396a7..02e56e1b0b 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -201,7 +201,7 @@ class TestMauLimit(unittest.HomeserverTestCase): } ) - request, channel = self.make_request("POST", "/register", request_data) + channel = self.make_request("POST", "/register", request_data) if channel.code != 200: raise HttpResponseException( @@ -213,7 +213,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return access_token def do_sync_for_user(self, token): - request, channel = self.make_request("GET", "/sync", access_token=token) + channel = self.make_request("GET", "/sync", access_token=token) if channel.code != 200: raise HttpResponseException( diff --git a/tests/test_server.py b/tests/test_server.py index 0be8c7e2fc..815da18e65 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -84,7 +84,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -108,7 +108,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -126,7 +126,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") @@ -148,9 +148,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - _, channel = make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar" - ) + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar") self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") @@ -172,7 +170,7 @@ class JsonResourceTests(unittest.TestCase): ) # The path was registered as GET, but this is a HEAD request. 
- _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") + channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) @@ -204,7 +202,7 @@ class OptionsResourceTests(unittest.TestCase): ) # render the request and return the channel - _, channel = make_request(self.reactor, site, method, path, shorthand=False) + channel = make_request(self.reactor, site, method, path, shorthand=False) return channel def test_unknown_options_request(self): @@ -277,7 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"200") body = channel.result["body"] @@ -295,7 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"301") headers = channel.result["headers"] @@ -316,7 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"304") headers = channel.result["headers"] @@ -335,7 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") + channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 71580b454d..a743cdc3a9 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -53,7 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase): def test_ui_auth(self): # Do a UI auth request request_data = json.dumps({"username": "kermit", "password": "monkey"}) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"401", channel.result) @@ -96,7 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase): self.registration_handler.check_username = Mock(return_value=True) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) # We don't bother checking that the response is correct - we'll leave that to # other tests. We just want to make sure we're on the right path. @@ -113,7 +113,7 @@ class TermsTestCase(unittest.HomeserverTestCase): }, } ) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) # We're interested in getting a response that looks like a successful # registration, not so much that the details are exactly what we want. 
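The sweep above converts every call site from tuple unpacking to the new single-return form. A minimal before/after sketch of the pattern — the endpoint and token here are illustrative, not taken from any one test:

    # before: make_request returned (SynapseRequest, FakeChannel), and the
    # request half was almost always discarded
    request, channel = self.make_request("GET", "/sync", access_token=tok)

    # after: only the FakeChannel is returned; it records everything the
    # tests actually assert on
    channel = self.make_request("GET", "/sync", access_token=tok)
    self.assertEqual(channel.code, 200)  # HTTP status of the recorded response
    body = channel.json_body             # parsed JSON body, where applicable

Helpers that only need the response can now simply `return self.make_request(...)`, as the shared-rooms helper above does.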
diff --git a/tests/unittest.py b/tests/unittest.py index 102b0a1f34..16fd4f32b9 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -372,38 +372,6 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ - # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest - # when the `request` arg isn't given, so we define an explicit override to - # cover that case. - @overload - def make_request( - self, - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, dict] = b"", - access_token: Optional[str] = None, - shorthand: bool = True, - federation_auth_origin: str = None, - content_is_form: bool = False, - await_result: bool = True, - ) -> Tuple[SynapseRequest, FakeChannel]: - ... - - @overload - def make_request( - self, - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, dict] = b"", - access_token: Optional[str] = None, - request: Type[T] = SynapseRequest, - shorthand: bool = True, - federation_auth_origin: str = None, - content_is_form: bool = False, - await_result: bool = True, - ) -> Tuple[T, FakeChannel]: - ... - def make_request( self, method: Union[bytes, str], @@ -415,7 +383,7 @@ class HomeserverTestCase(TestCase): federation_auth_origin: str = None, content_is_form: bool = False, await_result: bool = True, - ) -> Tuple[T, FakeChannel]: + ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the given content. @@ -438,7 +406,7 @@ class HomeserverTestCase(TestCase): tells the channel the request is finished. Returns: - Tuple[synapse.http.site.SynapseRequest, channel] + The FakeChannel object which stores the result of the request. """ return make_request( self.reactor, @@ -568,7 +536,7 @@ class HomeserverTestCase(TestCase): self.hs.config.registration_shared_secret = "shared" # Create the user - request, channel = self.make_request("GET", "/_synapse/admin/v1/register") + channel = self.make_request("GET", "/_synapse/admin/v1/register") self.assertEqual(channel.code, 200, msg=channel.result) nonce = channel.json_body["nonce"] @@ -593,7 +561,7 @@ class HomeserverTestCase(TestCase): "inhibit_login": True, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "/_synapse/admin/v1/register", body.encode("utf8") ) self.assertEqual(channel.code, 200, channel.json_body) @@ -611,7 +579,7 @@ class HomeserverTestCase(TestCase): if device_id: body["device_id"] = device_id - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.assertEqual(channel.code, 200, channel.result) @@ -679,7 +647,7 @@ class HomeserverTestCase(TestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.assertEqual(channel.code, 403, channel.result) -- cgit 1.5.1 From c9dd47d66864714043b51a235909e3e957740c7b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 15 Dec 2020 22:28:06 +0000 Subject: lint --- tests/replication/test_auth.py | 2 -- tests/replication/test_client_reader_shard.py | 3 +-- tests/rest/client/v1/test_login.py | 2 +- tests/rest/client/v2_alpha/test_auth.py | 1 - tests/unittest.py | 2 +- 5 files changed, 3 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index c5ab3032a5..f35a5235e1 
100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Tuple -from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha import register from tests.replication._base import BaseMultiWorkerStreamTestCase diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index abcc74f932..4608b65a0c 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -14,11 +14,10 @@ # limitations under the License. import logging -from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha import register from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.server import FakeChannel, make_request +from tests.server import make_request logger = logging.getLogger(__name__) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 041f2766df..566776e97e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -715,7 +715,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): ] def register_as_user(self, username): - channel = self.make_request( + self.make_request( b"POST", "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), {"username": username}, diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index da866c8b44..51323b3da3 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -20,7 +20,6 @@ from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.http.site import SynapseRequest from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register from synapse.rest.oidc import OIDCResource diff --git a/tests/unittest.py b/tests/unittest.py index 16fd4f32b9..39e5e7b85c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload +from typing import Dict, Optional, Type, TypeVar, Union from mock import Mock, patch -- cgit 1.5.1 From 2dd2e90e2b0b78f82d0e3372bacba9a84240302b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 15 Dec 2020 12:39:56 +0000 Subject: Test `get_extra_attributes` fallback despite the warnings saying "don't implement get_extra_attributes", we had implemented it, so the tests weren't doing what we thought they were. 
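The diff below drops the `OidcMappingProvider` base class from `TestMappingProvider` so that `get_extra_attributes` is genuinely absent and the fallback path runs, and the extra-attributes argument expected by the assertions changes from `{}` to `None`. As a minimal sketch of the optional-method fallback being exercised — the general pattern, assuming the handler probes with `getattr`, not necessarily Synapse's exact code:

    async def _get_extra_attributes(provider, userinfo, token):
        # Mapping providers may omit get_extra_attributes entirely; in that
        # case the handler passes None through to complete_sso_login, which
        # is what the updated assertions below expect.
        get_extra = getattr(provider, "get_extra_attributes", None)
        if get_extra is None:
            return None
        return await get_extra(userinfo, token)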
--- tests/handlers/test_oidc.py | 32 +++++++++++++++++++++----------- tests/rest/client/v1/utils.py | 2 +- 2 files changed, 22 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 464e569ac8..1794a169e0 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -19,7 +19,7 @@ from mock import ANY, Mock, patch import pymacaroons -from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider +from synapse.handlers.oidc_handler import OidcError from synapse.handlers.sso import MappingException from synapse.types import UserID @@ -55,11 +55,14 @@ COOKIE_NAME = b"oidc_session" COOKIE_PATH = "/_synapse/oidc" -class TestMappingProvider(OidcMappingProvider): +class TestMappingProvider: @staticmethod def parse_config(config): return + def __init__(self, config): + pass + def get_remote_user_id(self, userinfo): return userinfo["sub"] @@ -360,6 +363,13 @@ class OidcHandlerTestCase(HomeserverTestCase): - when the userinfo fetching fails - when the code exchange fails """ + + # ensure that we are correctly testing the fallback when "get_extra_attributes" + # is not implemented. + mapping_provider = self.handler._user_mapping_provider + with self.assertRaises(AttributeError): + _ = mapping_provider.get_extra_attributes + token = { "type": "bearer", "id_token": "id_token", @@ -396,7 +406,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, {}, + expected_user_id, request, client_redirect_url, None, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -427,7 +437,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, {}, + expected_user_id, request, client_redirect_url, None, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() @@ -616,7 +626,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self._make_callback_with_userinfo(userinfo) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", ANY, ANY, {} + "@test_user:test", ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() @@ -627,7 +637,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self._make_callback_with_userinfo(userinfo) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", ANY, ANY, {} + "@test_user_2:test", ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() @@ -664,14 +674,14 @@ class OidcHandlerTestCase(HomeserverTestCase): } self._make_callback_with_userinfo(userinfo) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, {}, + user.to_string(), ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. 
self._make_callback_with_userinfo(userinfo) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, {}, + user.to_string(), ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() @@ -686,7 +696,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self._make_callback_with_userinfo(userinfo) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, {}, + user.to_string(), ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() @@ -722,7 +732,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self._make_callback_with_userinfo(userinfo) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", ANY, ANY, {}, + "@TEST_USER_2:test", ANY, ANY, None, ) def test_map_userinfo_to_invalid_localpart(self): @@ -756,7 +766,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", ANY, ANY, {}, + "@test_user1:test", ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 5a18af8d34..1b31669489 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -445,7 +445,7 @@ class RestHelper: return channel.json_body -# an 'oidc_config' suitable for login_with_oidc. +# an 'oidc_config' suitable for login_via_oidc. TEST_OIDC_CONFIG = { "enabled": True, "discover": False, -- cgit 1.5.1 From c1883f042d4e6d69e4c211bcad5d65da5123f33d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 15 Dec 2020 23:00:03 +0000 Subject: Remove spurious mocking of complete_sso_login The tests that need this all do it already. --- tests/handlers/test_oidc.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1794a169e0..bd24375018 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -796,8 +796,6 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._exchange_code = simple_async_mock(return_value={}) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() state = "state" session = self.handler._generate_oidc_session_token( -- cgit 1.5.1 From 8388a7fb3a02e50cd2dded8f7e43235c42ac597b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 15 Dec 2020 13:03:31 +0000 Subject: Make `_make_callback_with_userinfo` async ... so that we can test its behaviour when it raises. Also pull it out to the top level so that I can use it from other test classes. 
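Because the helper becomes a plain module-level coroutine in the diff below, other test classes can await it directly — and, per the commit message, error cases can be asserted with something like the test base class's `get_failure` helper, which runs an awaitable to completion and checks the exception type. A hypothetical sketch; the exception class and userinfo values are placeholders, not taken from this patch:

    class Boom(Exception):
        """Placeholder for an error the callback is expected to propagate."""

    def test_callback_failure(self):  # inside a HomeserverTestCase subclass
        self.get_failure(
            _make_callback_with_userinfo(self.hs, {"sub": "u", "username": "u"}),
            Boom,
        )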
--- tests/handlers/test_oidc.py | 151 ++++++++++++++++++++++++-------------------- 1 file changed, 83 insertions(+), 68 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index bd24375018..c54f1c5797 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -21,6 +21,7 @@ import pymacaroons from synapse.handlers.oidc_handler import OidcError from synapse.handlers.sso import MappingException +from synapse.server import HomeServer from synapse.types import UserID from tests.test_utils import FakeResponse, simple_async_mock @@ -399,7 +400,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - request = self._build_callback_request( + request = _build_callback_request( code, state, session, user_agent=user_agent, ip_address=ip_address ) @@ -607,7 +608,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - request = self._build_callback_request("code", state, session) + request = _build_callback_request("code", state, session) self.get_success(self.handler.handle_oidc_callback(request)) @@ -624,7 +625,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test_user", "username": "test_user", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( "@test_user:test", ANY, ANY, None, ) @@ -635,7 +636,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": 1234, "username": "test_user_2", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( "@test_user_2:test", ANY, ANY, None, ) @@ -648,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", @@ -672,14 +673,14 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( user.to_string(), ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. 
- self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( user.to_string(), ANY, ANY, None, ) @@ -694,7 +695,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( user.to_string(), ANY, ANY, None, ) @@ -715,7 +716,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( @@ -730,14 +731,16 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( "@TEST_USER_2:test", ANY, ANY, None, ) def test_map_userinfo_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" - self._make_callback_with_userinfo({"sub": "test2", "username": "föö"}) + self.get_success( + _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) + ) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -762,7 +765,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( @@ -784,68 +787,80 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) - def _make_callback_with_userinfo( - self, userinfo: dict, client_redirect_url: str = "http://client/redirect" - ) -> None: - self.handler._exchange_code = simple_async_mock(return_value={}) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) - state = "state" - session = self.handler._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, - ) - request = self._build_callback_request("code", state, session) +async def _make_callback_with_userinfo( + hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" +) -> None: + """Mock up an OIDC callback with the given userinfo dict - self.get_success(self.handler.handle_oidc_callback(request)) + We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, + and poke in the userinfo dict as if it were the response to an OIDC userinfo call. 
- def _build_callback_request( - self, - code: str, - state: str, - session: str, - user_agent: str = "Browser", - ip_address: str = "10.0.0.1", - ): - """Builds a fake SynapseRequest to mock the browser callback - - Returns a Mock object which looks like the SynapseRequest we get from a browser - after SSO (before we return to the client) - - Args: - code: the authorization code which would have been returned by the OIDC - provider - state: the "state" param which would have been passed around in the - query param. Should be the same as was embedded in the session in - _build_oidc_session. - session: the "session" which would have been passed around in the cookie. - user_agent: the user-agent to present - ip_address: the IP address to pretend the request came from - """ - request = Mock( - spec=[ - "args", - "getCookie", - "addCookie", - "requestHeaders", - "getClientIP", - "get_user_agent", - ] - ) + Args: + hs: the HomeServer impl to send the callback to. + userinfo: the OIDC userinfo dict + client_redirect_url: the URL to redirect to on success. + """ + handler = hs.get_oidc_handler() + handler._exchange_code = simple_async_mock(return_value={}) + handler._parse_id_token = simple_async_mock(return_value=userinfo) + handler._fetch_userinfo = simple_async_mock(return_value=userinfo) - request.getCookie.return_value = session - request.args = {} - request.args[b"code"] = [code.encode("utf-8")] - request.args[b"state"] = [state.encode("utf-8")] - request.getClientIP.return_value = ip_address - request.get_user_agent.return_value = user_agent - return request + state = "state" + session = handler._generate_oidc_session_token( + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, + ) + request = _build_callback_request("code", state, session) + + await handler.handle_oidc_callback(request) + + +def _build_callback_request( + code: str, + state: str, + session: str, + user_agent: str = "Browser", + ip_address: str = "10.0.0.1", +): + """Builds a fake SynapseRequest to mock the browser callback + + Returns a Mock object which looks like the SynapseRequest we get from a browser + after SSO (before we return to the client) + + Args: + code: the authorization code which would have been returned by the OIDC + provider + state: the "state" param which would have been passed around in the + query param. Should be the same as was embedded in the session in + _build_oidc_session. + session: the "session" which would have been passed around in the cookie. + user_agent: the user-agent to present + ip_address: the IP address to pretend the request came from + """ + request = Mock( + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] + ) + + request.getCookie.return_value = session + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + request.getClientIP.return_value = ip_address + request.get_user_agent.return_value = user_agent + return request -- cgit 1.5.1 From bd30cfe86a5413191fe44d8f937a00117334ea82 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Dec 2020 11:25:30 -0500 Subject: Convert internal pusher dicts to attrs classes. (#8940) This improves type hinting and should use less memory. 
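For illustration (not part of the patch; the field values are made up), the
new config object is built with keyword arguments and read via attributes,
and `as_dict` exposes only the client-visible fields that the ALLOWED_KEYS /
_GET_PUSHERS_ALLOWED_KEYS sets removed below used to filter for:

    from synapse.push import PusherConfig

    config = PusherConfig(
        id=None,
        user_name="@alice:example.com",
        access_token=None,
        profile_tag="",
        kind="http",
        app_id="com.example.app",
        app_display_name="Example App",
        device_display_name="Alice's phone",
        pushkey="example_push_key",
        ts=1608120000000,
        lang="en",
        data={"url": "https://push.example.com/_matrix/push/v1/notify"},
        last_stream_ordering=None,
        last_success=None,
        failing_since=None,
    )

    # Attribute access replaces the old dict lookups, and `slots=True` means
    # a misspelled attribute raises AttributeError instead of silently
    # creating a new key (it also drops the per-instance __dict__, which is
    # where the memory saving comes from).
    assert config.kind == "http"
    # Secrets such as the access token never appear in the API-facing dict:
    assert "access_token" not in config.as_dict()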
--- changelog.d/8940.misc | 1 + mypy.ini | 1 + synapse/push/__init__.py | 60 +++++++-- synapse/push/emailpusher.py | 27 +++-- synapse/push/httppusher.py | 36 +++--- synapse/push/pusher.py | 24 ++-- synapse/push/pusherpool.py | 135 +++++++++++---------- .../slave/storage/_slaved_id_tracker.py | 20 ++- synapse/replication/slave/storage/pushers.py | 17 ++- synapse/rest/admin/users.py | 16 +-- synapse/rest/client/v1/pusher.py | 15 +-- synapse/storage/databases/main/__init__.py | 3 - synapse/storage/databases/main/pusher.py | 93 ++++++++------ synapse/storage/util/id_generators.py | 4 +- tests/push/test_email.py | 6 +- tests/push/test_http.py | 10 +- tests/rest/admin/test_user.py | 2 +- 17 files changed, 266 insertions(+), 204 deletions(-) create mode 100644 changelog.d/8940.misc (limited to 'tests') diff --git a/changelog.d/8940.misc b/changelog.d/8940.misc new file mode 100644 index 0000000000..4ff0b94b94 --- /dev/null +++ b/changelog.d/8940.misc @@ -0,0 +1 @@ +Add type hints to push module. diff --git a/mypy.ini b/mypy.ini index 334e3a22fb..1904204025 100644 --- a/mypy.ini +++ b/mypy.ini @@ -65,6 +65,7 @@ files = synapse/state, synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, + synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index ad07ee86f6..9e7ac149a1 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -14,24 +14,70 @@ # limitations under the License. import abc -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional -from synapse.types import RoomStreamToken +import attr + +from synapse.types import JsonDict, RoomStreamToken if TYPE_CHECKING: from synapse.app.homeserver import HomeServer +@attr.s(slots=True) +class PusherConfig: + """Parameters necessary to configure a pusher.""" + + id = attr.ib(type=Optional[str]) + user_name = attr.ib(type=str) + access_token = attr.ib(type=Optional[int]) + profile_tag = attr.ib(type=str) + kind = attr.ib(type=str) + app_id = attr.ib(type=str) + app_display_name = attr.ib(type=str) + device_display_name = attr.ib(type=str) + pushkey = attr.ib(type=str) + ts = attr.ib(type=int) + lang = attr.ib(type=Optional[str]) + data = attr.ib(type=Optional[JsonDict]) + last_stream_ordering = attr.ib(type=Optional[int]) + last_success = attr.ib(type=Optional[int]) + failing_since = attr.ib(type=Optional[int]) + + def as_dict(self) -> Dict[str, Any]: + """Information that can be retrieved about a pusher after creation.""" + return { + "app_display_name": self.app_display_name, + "app_id": self.app_id, + "data": self.data, + "device_display_name": self.device_display_name, + "kind": self.kind, + "lang": self.lang, + "profile_tag": self.profile_tag, + "pushkey": self.pushkey, + } + + +@attr.s(slots=True) +class ThrottleParams: + """Parameters for controlling the rate of sending pushes via email.""" + + last_sent_ts = attr.ib(type=int) + throttle_ms = attr.ib(type=int) + + class Pusher(metaclass=abc.ABCMeta): - def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]): + def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): self.hs = hs self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.pusher_id = pusherdict["id"] - self.user_id = pusherdict["user_name"] - self.app_id = pusherdict["app_id"] - self.pushkey = pusherdict["pushkey"] + 
self.pusher_id = pusher_config.id + self.user_id = pusher_config.user_name + self.app_id = pusher_config.app_id + self.pushkey = pusher_config.pushkey + + self.last_stream_ordering = pusher_config.last_stream_ordering # This is the highest stream ordering we know it's safe to process. # When new events arrive, we'll be given a window of new events: we diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 11a97b8df4..d2eff75a58 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -14,13 +14,13 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from twisted.internet.base import DelayedCall from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import Pusher +from synapse.push import Pusher, PusherConfig, ThrottleParams from synapse.push.mailer import Mailer if TYPE_CHECKING: @@ -60,15 +60,14 @@ class EmailPusher(Pusher): factor out the common parts """ - def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer): - super().__init__(hs, pusherdict) + def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer): + super().__init__(hs, pusher_config) self.mailer = mailer self.store = self.hs.get_datastore() - self.email = pusherdict["pushkey"] - self.last_stream_ordering = pusherdict["last_stream_ordering"] + self.email = pusher_config.pushkey self.timed_call = None # type: Optional[DelayedCall] - self.throttle_params = {} # type: Dict[str, Dict[str, int]] + self.throttle_params = {} # type: Dict[str, ThrottleParams] self._inited = False self._is_processing = False @@ -132,6 +131,7 @@ class EmailPusher(Pusher): if not self._inited: # this is our first loop: load up the throttle params + assert self.pusher_id is not None self.throttle_params = await self.store.get_throttle_params_by_room( self.pusher_id ) @@ -157,6 +157,7 @@ class EmailPusher(Pusher): being run. 
""" start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering + assert start is not None unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email( self.user_id, start, self.max_stream_ordering ) @@ -244,13 +245,13 @@ class EmailPusher(Pusher): def get_room_throttle_ms(self, room_id: str) -> int: if room_id in self.throttle_params: - return self.throttle_params[room_id]["throttle_ms"] + return self.throttle_params[room_id].throttle_ms else: return 0 def get_room_last_sent_ts(self, room_id: str) -> int: if room_id in self.throttle_params: - return self.throttle_params[room_id]["last_sent_ts"] + return self.throttle_params[room_id].last_sent_ts else: return 0 @@ -301,10 +302,10 @@ class EmailPusher(Pusher): new_throttle_ms = min( current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS ) - self.throttle_params[room_id] = { - "last_sent_ts": self.clock.time_msec(), - "throttle_ms": new_throttle_ms, - } + self.throttle_params[room_id] = ThrottleParams( + self.clock.time_msec(), new_throttle_ms, + ) + assert self.pusher_id is not None await self.store.set_throttle_params( self.pusher_id, room_id, self.throttle_params[room_id] ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index e8b25bcd2a..417fe0f1f5 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import Pusher, PusherConfigException +from synapse.push import Pusher, PusherConfig, PusherConfigException from . import push_rule_evaluator, push_tools @@ -62,33 +62,29 @@ class HttpPusher(Pusher): # This one's in ms because we compare it against the clock GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000 - def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]): - super().__init__(hs, pusherdict) + def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): + super().__init__(hs, pusher_config) self.storage = self.hs.get_storage() - self.app_display_name = pusherdict["app_display_name"] - self.device_display_name = pusherdict["device_display_name"] - self.pushkey_ts = pusherdict["ts"] - self.data = pusherdict["data"] - self.last_stream_ordering = pusherdict["last_stream_ordering"] + self.app_display_name = pusher_config.app_display_name + self.device_display_name = pusher_config.device_display_name + self.pushkey_ts = pusher_config.ts + self.data = pusher_config.data self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.failing_since = pusherdict["failing_since"] + self.failing_since = pusher_config.failing_since self.timed_call = None self._is_processing = False self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room - if "data" not in pusherdict: - raise PusherConfigException("No 'data' key for HTTP pusher") - self.data = pusherdict["data"] + self.data = pusher_config.data + if self.data is None: + raise PusherConfigException("'data' key can not be null for HTTP pusher") self.name = "%s/%s/%s" % ( - pusherdict["user_name"], - pusherdict["app_id"], - pusherdict["pushkey"], + pusher_config.user_name, + pusher_config.app_id, + pusher_config.pushkey, ) - if self.data is None: - raise PusherConfigException("data can not be null for HTTP pusher") - # Validate that there's a URL and it is of the proper form. 
if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") @@ -180,6 +176,7 @@ class HttpPusher(Pusher): Never call this directly: use _process which will only allow this to run once per pusher. """ + assert self.last_stream_ordering is not None unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -208,6 +205,7 @@ class HttpPusher(Pusher): http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] + assert self.last_stream_ordering is not None pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, self.pushkey, @@ -314,6 +312,8 @@ class HttpPusher(Pusher): # or may do so (i.e. is encrypted so has unknown effects). priority = "high" + # This was checked in the __init__, but mypy doesn't seem to know that. + assert self.data is not None if self.data.get("format") == "event_id_only": d = { "notification": { diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 8f1072b094..2aa7918fb4 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -14,9 +14,9 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from typing import TYPE_CHECKING, Callable, Dict, Optional -from synapse.push import Pusher +from synapse.push import Pusher, PusherConfig from synapse.push.emailpusher import EmailPusher from synapse.push.httppusher import HttpPusher from synapse.push.mailer import Mailer @@ -34,7 +34,7 @@ class PusherFactory: self.pusher_types = { "http": HttpPusher - } # type: Dict[str, Callable[[HomeServer, dict], Pusher]] + } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] logger.info("email enable notifs: %r", hs.config.email_enable_notifs) if hs.config.email_enable_notifs: @@ -47,18 +47,18 @@ class PusherFactory: logger.info("defined email pusher type") - def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]: - kind = pusherdict["kind"] + def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: + kind = pusher_config.kind f = self.pusher_types.get(kind, None) if not f: return None - logger.debug("creating %s pusher for %r", kind, pusherdict) - return f(self.hs, pusherdict) + logger.debug("creating %s pusher for %r", kind, pusher_config) + return f(self.hs, pusher_config) def _create_email_pusher( - self, _hs: "HomeServer", pusherdict: Dict[str, Any] + self, _hs: "HomeServer", pusher_config: PusherConfig ) -> EmailPusher: - app_name = self._app_name_from_pusherdict(pusherdict) + app_name = self._app_name_from_pusherdict(pusher_config) mailer = self.mailers.get(app_name) if not mailer: mailer = Mailer( @@ -68,10 +68,10 @@ class PusherFactory: template_text=self._notif_template_text, ) self.mailers[app_name] = mailer - return EmailPusher(self.hs, pusherdict, mailer) + return EmailPusher(self.hs, pusher_config, mailer) - def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str: - data = pusherdict["data"] + def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str: + data = pusher_config.data if isinstance(data, dict): brand = data.get("brand") diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 9c12d81cfb..8158356d40 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -15,7 +15,7 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Dict, Iterable, Optional from prometheus_client import Gauge @@ -23,9 +23,9 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.push import Pusher, PusherConfigException +from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.push.pusher import PusherFactory -from synapse.types import RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken from synapse.util.async_helpers import concurrently_execute if TYPE_CHECKING: @@ -77,7 +77,7 @@ class PusherPool: # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Pusher]] - def start(self): + def start(self) -> None: """Starts the pushers off in a background process. """ if not self._should_start_pushers: @@ -87,16 +87,16 @@ class PusherPool: async def add_pusher( self, - user_id, - access_token, - kind, - app_id, - app_display_name, - device_display_name, - pushkey, - lang, - data, - profile_tag="", + user_id: str, + access_token: Optional[int], + kind: str, + app_id: str, + app_display_name: str, + device_display_name: str, + pushkey: str, + lang: Optional[str], + data: JsonDict, + profile_tag: str = "", ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -111,21 +111,23 @@ class PusherPool: # recreated, added and started: this means we have only one # code path adding pushers. self.pusher_factory.create_pusher( - { - "id": None, - "user_name": user_id, - "kind": kind, - "app_id": app_id, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "pushkey": pushkey, - "ts": time_now_msec, - "lang": lang, - "data": data, - "last_stream_ordering": None, - "last_success": None, - "failing_since": None, - } + PusherConfig( + id=None, + user_name=user_id, + access_token=access_token, + profile_tag=profile_tag, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + ts=time_now_msec, + lang=lang, + data=data, + last_stream_ordering=None, + last_success=None, + failing_since=None, + ) ) # create the pusher setting last_stream_ordering to the current maximum @@ -151,43 +153,44 @@ class PusherPool: return pusher async def remove_pushers_by_app_id_and_pushkey_not_user( - self, app_id, pushkey, not_user_id - ): + self, app_id: str, pushkey: str, not_user_id: str + ) -> None: to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: - if p["user_name"] != not_user_id: + if p.user_name != not_user_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", app_id, pushkey, - p["user_name"], + p.user_name, ) - await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p.app_id, p.pushkey, p.user_name) - async def remove_pushers_by_access_token(self, user_id, access_tokens): + async def remove_pushers_by_access_token( + self, user_id: str, access_tokens: Iterable[int] + ) -> None: """Remove the pushers for a given user corresponding to a set of access_tokens. 
Args: - user_id (str): user to remove pushers for - access_tokens (Iterable[int]): access token *ids* to remove pushers - for + user_id: user to remove pushers for + access_tokens: access token *ids* to remove pushers for """ if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return tokens = set(access_tokens) for p in await self.store.get_pushers_by_user_id(user_id): - if p["access_token"] in tokens: + if p.access_token in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", - p["app_id"], - p["pushkey"], - p["user_name"], + p.app_id, + p.pushkey, + p.user_name, ) - await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p.app_id, p.pushkey, p.user_name) - def on_new_notifications(self, max_token: RoomStreamToken): + def on_new_notifications(self, max_token: RoomStreamToken) -> None: if not self.pushers: # nothing to do here. return @@ -206,7 +209,7 @@ class PusherPool: self._on_new_notifications(max_token) @wrap_as_background_process("on_new_notifications") - async def _on_new_notifications(self, max_token: RoomStreamToken): + async def _on_new_notifications(self, max_token: RoomStreamToken) -> None: # We just use the minimum stream ordering and ignore the vector clock # component. This is safe to do as long as we *always* ignore the vector # clock components. @@ -236,7 +239,9 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + async def on_new_receipts( + self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str] + ) -> None: if not self.pushers: # nothing to do here. return @@ -280,14 +285,14 @@ class PusherPool: resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) - pusher_dict = None + pusher_config = None for r in resultlist: - if r["user_name"] == user_id: - pusher_dict = r + if r.user_name == user_id: + pusher_config = r pusher = None - if pusher_dict: - pusher = await self._start_pusher(pusher_dict) + if pusher_config: + pusher = await self._start_pusher(pusher_config) return pusher @@ -302,44 +307,44 @@ class PusherPool: logger.info("Started pushers") - async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]: + async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: """Start the given pusher Args: - pusherdict: dict with the values pulled from the db table + pusher_config: The pusher configuration with the values pulled from the db table Returns: The newly created pusher or None. 
""" if not self._pusher_shard_config.should_handle( - self._instance_name, pusherdict["user_name"] + self._instance_name, pusher_config.user_name ): return None try: - p = self.pusher_factory.create_pusher(pusherdict) + p = self.pusher_factory.create_pusher(pusher_config) except PusherConfigException as e: logger.warning( "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", - pusherdict["id"], - pusherdict.get("user_name"), - pusherdict.get("app_id"), - pusherdict.get("pushkey"), + pusher_config.id, + pusher_config.user_name, + pusher_config.app_id, + pusher_config.pushkey, e, ) return None except Exception: logger.exception( - "Couldn't start pusher id %i: caught Exception", pusherdict["id"], + "Couldn't start pusher id %i: caught Exception", pusher_config.id, ) return None if not p: return None - appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) + appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey) - byuser = self.pushers.setdefault(pusherdict["user_name"], {}) + byuser = self.pushers.setdefault(pusher_config.user_name, {}) if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() byuser[appid_pushkey] = p @@ -349,8 +354,8 @@ class PusherPool: # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to # push. - user_id = pusherdict["user_name"] - last_stream_ordering = pusherdict["last_stream_ordering"] + user_id = pusher_config.user_name + last_stream_ordering = pusher_config.last_stream_ordering if last_stream_ordering: have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering @@ -364,7 +369,7 @@ class PusherPool: return p - async def remove_pusher(self, app_id, pushkey, user_id): + async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: appid_pushkey = "%s:%s" % (app_id, pushkey) byuser = self.pushers.get(user_id, {}) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index eb74903d68..0d39a93ed2 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -12,21 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import List, Optional, Tuple
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import _load_current_id
 
 
 class SlavedIdTracker:
-    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+    def __init__(
+        self,
+        db_conn: Connection,
+        table: str,
+        column: str,
+        extra_tables: Optional[List[Tuple[str, str]]] = None,
+        step: int = 1,
+    ):
         self.step = step
         self._current = _load_current_id(db_conn, table, column, step)
-        for table, column in extra_tables:
-            self.advance(None, _load_current_id(db_conn, table, column))
+        if extra_tables:
+            for table, column in extra_tables:
+                self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, instance_name, new_id):
+    def advance(self, instance_name: Optional[str], new_id: int):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         """
         Returns:
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730ba8..045bd014da 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import TYPE_CHECKING
 
 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
-        self._pushers_id_gen = SlavedIdTracker(
+        self._pushers_id_gen = SlavedIdTracker(  # type: ignore
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
 
-    def get_pushers_stream_token(self):
+    def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token, rows
+    ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)
+            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
             return
         super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 88cba369f5..6658c2da56 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -42,17 +42,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
-_GET_PUSHERS_ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class UsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@@ -770,10 +759,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.store.get_pushers_by_user_id(user_id)
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
-            for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 8fe83f321a..89823fcc39 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) -ALLOWED_KEYS = { - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", -} - class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) @@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - filtered_pushers = [ - {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers - ] + filtered_pushers = [p.as_dict() for p in pushers] return 200, {"pushers": filtered_pushers} diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 43660ec4fb..871fb646a5 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -149,9 +149,6 @@ class DataStore( self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) self._group_updates_id_gen = StreamIdGenerator( db_conn, "local_group_updates", "stream_id" ) diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 7997242d90..77ba9d819e 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -15,18 +15,32 @@ # limitations under the License. 
import logging -from typing import Iterable, Iterator, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple from canonicaljson import encode_canonical_json +from synapse.push import PusherConfig, ThrottleParams from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.types import Connection +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + self._pushers_id_gen = StreamIdGenerator( + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] + ) + + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: """JSON-decode the data in the rows returned from the `pushers` table Drops any rows whose data cannot be decoded @@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore): ) continue - yield r + yield PusherConfig(**r) - async def user_has_pusher(self, user_id): + async def user_has_pusher(self, user_id: str) -> bool: ret = await self.db_pool.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None - def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): - return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) + async def get_pushers_by_app_id_and_pushkey( + self, app_id: str, pushkey: str + ) -> Iterator[PusherConfig]: + return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) - def get_pushers_by_user_id(self, user_id): - return self.get_pushers_by({"user_name": user_id}) + async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]: + return await self.get_pushers_by({"user_name": user_id}) - async def get_pushers_by(self, keyvalues): + async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]: ret = await self.db_pool.simple_select_list( "pushers", keyvalues, @@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore): ) return self._decode_pushers_rows(ret) - async def get_all_pushers(self): + async def get_all_pushers(self) -> Iterator[PusherConfig]: def get_pushers(txn): txn.execute("SELECT * FROM pushers") rows = self.db_pool.cursor_to_dict(txn) @@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore): ) @cached(num_args=1, max_entries=15000) - async def get_if_user_has_pusher(self, user_id): + async def get_if_user_has_pusher(self, user_id: str): # This only exists for the cachedList decorator raise NotImplementedError() @cachedList( cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1, ) - async def get_if_users_have_pushers(self, user_ids): + async def get_if_users_have_pushers( + self, user_ids: Iterable[str] + ) -> Dict[str, bool]: rows = await self.db_pool.simple_select_many_batch( table="pushers", column="user_name", @@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore): return bool(updated) async def update_pusher_failing_since( - self, app_id, pushkey, user_id, failing_since + self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int] ) -> None: await self.db_pool.simple_update( 
table="pushers", @@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore): desc="update_pusher_failing_since", ) - async def get_throttle_params_by_room(self, pusher_id): + async def get_throttle_params_by_room( + self, pusher_id: str + ) -> Dict[str, ThrottleParams]: res = await self.db_pool.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, @@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore): params_by_room = {} for row in res: - params_by_room[row["room_id"]] = { - "last_sent_ts": row["last_sent_ts"], - "throttle_ms": row["throttle_ms"], - } + params_by_room[row["room_id"]] = ThrottleParams( + row["last_sent_ts"], row["throttle_ms"], + ) return params_by_room - async def set_throttle_params(self, pusher_id, room_id, params) -> None: + async def set_throttle_params( + self, pusher_id: str, room_id: str, params: ThrottleParams + ) -> None: # no need to lock because `pusher_throttle` has a primary key on # (pusher, room_id) so simple_upsert will retry await self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, - params, + {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms}, desc="set_throttle_params", lock=False, ) class PusherStore(PusherWorkerStore): - def get_pushers_stream_token(self): + def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() async def add_pusher( self, - user_id, - access_token, - kind, - app_id, - app_display_name, - device_display_name, - pushkey, - pushkey_ts, - lang, - data, - last_stream_ordering, - profile_tag="", + user_id: str, + access_token: Optional[int], + kind: str, + app_id: str, + app_display_name: str, + device_display_name: str, + pushkey: str, + pushkey_ts: int, + lang: Optional[str], + data: Optional[JsonDict], + last_stream_ordering: int, + profile_tag: str = "", ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore): # invalidate, since we the user might not have had a pusher before await self.db_pool.runInteraction( "add_pusher", - self._invalidate_cache_and_stream, + self._invalidate_cache_and_stream, # type: ignore self.get_if_user_has_pusher, (user_id,), ) async def delete_pusher_by_app_id_pushkey_user_id( - self, app_id, pushkey, user_id + self, app_id: str, pushkey: str, user_id: str ) -> None: def delete_pusher_txn(txn, stream_id): - self._invalidate_cache_and_stream( + self._invalidate_cache_and_stream( # type: ignore txn, self.get_if_user_has_pusher, (user_id,) ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 02d71302ea..133c0e7a28 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -153,12 +153,12 @@ class StreamIdGenerator: return _AsyncCtxManagerWrapper(manager()) - def get_current_token(self): + def get_current_token(self) -> int: """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. Returns: - int + The maximum stream id. 
""" with self._lock: if self._unfinished_ids: diff --git a/tests/push/test_email.py b/tests/push/test_email.py index bcdcafa5a9..961bf09de9 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -209,7 +209,7 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - last_stream_ordering = pushers[0]["last_stream_ordering"] + last_stream_ordering = pushers[0].last_stream_ordering # Advance time a bit, so the pusher will register something has happened self.pump(10) @@ -220,7 +220,7 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) # One email was attempted to be sent self.assertEqual(len(self.email_attempts), 1) @@ -238,4 +238,4 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index cb3245d8cf..60f0820cff 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -144,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - last_stream_ordering = pushers[0]["last_stream_ordering"] + last_stream_ordering = pushers[0].last_stream_ordering # Advance time a bit, so the pusher will register something has happened self.pump() @@ -155,7 +155,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) # One push was attempted to be sent -- it'll be the first message self.assertEqual(len(self.push_attempts), 1) @@ -176,8 +176,8 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) - last_stream_ordering = pushers[0]["last_stream_ordering"] + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) + last_stream_ordering = pushers[0].last_stream_ordering # Now it'll try and send the second push message, which will be the second one self.assertEqual(len(self.push_attempts), 2) @@ -198,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) def test_sends_high_priority_for_encrypted(self): """ diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 582f983225..df62317e69 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -766,7 +766,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual("@bob:test", pushers[0]["user_name"]) + self.assertEqual("@bob:test", pushers[0].user_name) @override_config( { -- cgit 1.5.1 From e1b8e37f936b115e2164d272333c9b15342e6f88 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 16 Dec 2020 20:01:53 +0000 Subject: Push login completion 
down into SsoHandler (#8941) This is another part of my work towards fixing #8876. It moves some of the logic currently in the SAML and OIDC handlers - in particular the call to `AuthHandler.complete_sso_login` down into the `SsoHandler`. --- changelog.d/8941.feature | 1 + synapse/handlers/oidc_handler.py | 62 +++++++++++++++++----------------------- synapse/handlers/saml_handler.py | 37 ++++++++---------------- synapse/handlers/sso.py | 58 +++++++++++++++++++++++-------------- tests/handlers/test_saml.py | 8 +++--- 5 files changed, 80 insertions(+), 86 deletions(-) create mode 100644 changelog.d/8941.feature (limited to 'tests') diff --git a/changelog.d/8941.feature b/changelog.d/8941.feature new file mode 100644 index 0000000000..d450ef4998 --- /dev/null +++ b/changelog.d/8941.feature @@ -0,0 +1 @@ +Add support for allowing users to pick their own user ID during a single-sign-on login. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index f626117f76..cbd11a1382 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -115,8 +115,6 @@ class OidcHandler(BaseHandler): self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._http_client = hs.get_proxied_http_client() - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() self._server_name = hs.config.server_name # type: str self._macaroon_secret_key = hs.config.macaroon_secret_key @@ -689,33 +687,14 @@ class OidcHandler(BaseHandler): # otherwise, it's a login - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) - # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user( - userinfo, token, user_agent, ip_address + await self._complete_oidc_login( + userinfo, token, request, client_redirect_url ) except MappingException as e: logger.exception("Could not map user") self._sso_handler.render_error(request, "mapping_error", str(e)) - return - - # Mapping providers might not have get_extra_attributes: only call this - # method if it exists. - extra_attributes = None - get_extra_attributes = getattr( - self._user_mapping_provider, "get_extra_attributes", None - ) - if get_extra_attributes: - extra_attributes = await get_extra_attributes(userinfo, token) - - # and finally complete the login - await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url, extra_attributes - ) def _generate_oidc_session_token( self, @@ -838,10 +817,14 @@ class OidcHandler(BaseHandler): now = self.clock.time_msec() return now < expiry - async def _map_userinfo_to_user( - self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str - ) -> str: - """Maps a UserInfo object to a mxid. + async def _complete_oidc_login( + self, + userinfo: UserInfo, + token: Token, + request: SynapseRequest, + client_redirect_url: str, + ) -> None: + """Given a UserInfo response, complete the login flow UserInfo should have a claim that uniquely identifies users. This claim is usually `sub`, but can be configured with `oidc_config.subject_claim`. @@ -853,17 +836,16 @@ class OidcHandler(BaseHandler): If a user already exists with the mxid we've mapped and allow_existing_users is disabled, raise an exception. + Otherwise, render a redirect back to the client_redirect_url with a loginToken. 
+ Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. + request: The request to respond to + client_redirect_url: The redirect URL passed in by the client. Raises: MappingException: if there was an error while mapping some properties - - Returns: - The mxid of the user """ try: remote_user_id = self._remote_id_from_userinfo(userinfo) @@ -931,13 +913,23 @@ class OidcHandler(BaseHandler): return None - return await self._sso_handler.get_mxid_from_sso( + # Mapping providers might not have get_extra_attributes: only call this + # method if it exists. + extra_attributes = None + get_extra_attributes = getattr( + self._user_mapping_provider, "get_extra_attributes", None + ) + if get_extra_attributes: + extra_attributes = await get_extra_attributes(userinfo, token) + + await self._sso_handler.complete_sso_login_request( self._auth_provider_id, remote_user_id, - user_agent, - ip_address, + request, + client_redirect_url, oidc_response_to_user_attributes, grandfather_existing_users, + extra_attributes, ) def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 6001fe3e27..5fa7ab3f8b 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -58,8 +58,6 @@ class SamlHandler(BaseHandler): super().__init__(hs) self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_idp_entityid = hs.config.saml2_idp_entityid - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._grandfathered_mxid_source_attribute = ( @@ -229,40 +227,29 @@ class SamlHandler(BaseHandler): ) return - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) - # Call the mapper to register/login the user try: - user_id = await self._map_saml_response_to_user( - saml2_auth, relay_state, user_agent, ip_address - ) + await self._complete_saml_login(saml2_auth, request, relay_state) except MappingException as e: logger.exception("Could not map user") self._sso_handler.render_error(request, "mapping_error", str(e)) - return - await self._auth_handler.complete_sso_login(user_id, request, relay_state) - - async def _map_saml_response_to_user( + async def _complete_saml_login( self, saml2_auth: saml2.response.AuthnResponse, + request: SynapseRequest, client_redirect_url: str, - user_agent: str, - ip_address: str, - ) -> str: + ) -> None: """ - Given a SAML response, retrieve the user ID for it and possibly register the user. + Given a SAML response, complete the login flow + + Retrieves the remote user ID, registers the user if necessary, and serves + a redirect back to the client with a login-token. Args: saml2_auth: The parsed SAML2 response. + request: The request to respond to client_redirect_url: The redirect URL passed in by the client. - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. - - Returns: - The user ID associated with this response. Raises: MappingException if there was a problem mapping the response to a user. 
@@ -318,11 +305,11 @@ class SamlHandler(BaseHandler): return None - return await self._sso_handler.get_mxid_from_sso( + await self._sso_handler.complete_sso_login_request( self._auth_provider_id, remote_user_id, - user_agent, - ip_address, + request, + client_redirect_url, saml_response_to_remapped_user_attributes, grandfather_existing_users, ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 112a7d5b2c..f054b66a53 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -21,7 +21,8 @@ from twisted.web.http import Request from synapse.api.errors import RedirectException from synapse.http.server import respond_with_html -from synapse.types import UserID, contains_invalid_mxid_characters +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer if TYPE_CHECKING: @@ -119,15 +120,16 @@ class SsoHandler: # No match. return None - async def get_mxid_from_sso( + async def complete_sso_login_request( self, auth_provider_id: str, remote_user_id: str, - user_agent: str, - ip_address: str, + request: SynapseRequest, + client_redirect_url: str, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]], - ) -> str: + extra_login_attributes: Optional[JsonDict] = None, + ) -> None: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -146,12 +148,18 @@ class SsoHandler: given user-agent and IP address and the SSO ID is linked to this matrix ID for subsequent calls. + Finally, we generate a redirect to the supplied redirect uri, with a login token + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". + remote_user_id: The unique identifier from the SSO provider. - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. + + request: The request to respond to + + client_redirect_url: The redirect URL passed in by the client. + sso_to_matrix_id_mapper: A callable to generate the user attributes. The only parameter is an integer which represents the amount of times the returned mxid localpart mapping has failed. @@ -163,12 +171,13 @@ class SsoHandler: to the user. RedirectException to redirect to an additional page (e.g. to prompt the user for more information). + grandfather_existing_users: A callable which can return an previously existing matrix ID. The SSO ID is then linked to the returned matrix ID. - Returns: - The user ID associated with the SSO response. + extra_login_attributes: An optional dictionary of extra + attributes to be provided to the client in the login response. Raises: MappingException if there was a problem mapping the response to a user. @@ -181,28 +190,33 @@ class SsoHandler: # interstitial pages. with await self._mapping_lock.queue(auth_provider_id): # first of all, check if we already have a mapping for this user - previously_registered_user_id = await self.get_sso_user_by_remote_user_id( + user_id = await self.get_sso_user_by_remote_user_id( auth_provider_id, remote_user_id, ) - if previously_registered_user_id: - return previously_registered_user_id # Check for grandfathering of users. 
- if grandfather_existing_users: - previously_registered_user_id = await grandfather_existing_users() - if previously_registered_user_id: + if not user_id and grandfather_existing_users: + user_id = await grandfather_existing_users() + if user_id: # Future logins should also match this user ID. await self._store.record_user_external_id( - auth_provider_id, remote_user_id, previously_registered_user_id + auth_provider_id, remote_user_id, user_id ) - return previously_registered_user_id # Otherwise, generate a new user. - attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) - user_id = await self._register_mapped_user( - attributes, auth_provider_id, remote_user_id, user_agent, ip_address, - ) - return user_id + if not user_id: + attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) + user_id = await self._register_mapped_user( + attributes, + auth_provider_id, + remote_user_id, + request.get_user_agent(""), + request.getClientIP(), + ) + + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url, extra_login_attributes + ) async def _call_attribute_mapper( self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 69927cf6be..548038214b 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri" + "@test_user:test", request, "redirect_uri", None ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "" + "@test_user:test", request, "", None ) # Subsequent calls should map to the same mxid. @@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "" + "@test_user:test", request, "", None ) def test_map_saml_response_to_invalid_localpart(self): @@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", request, "" + "@test_user1:test", request, "", None ) auth_handler.complete_sso_login.reset_mock() -- cgit 1.5.1 From ff5c4da1289cb5e097902b3e55b771be342c29d6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Dec 2020 17:25:24 -0500 Subject: Add a maximum size for well-known lookups. (#8950) --- changelog.d/8950.misc | 1 + synapse/http/client.py | 32 ++++++++++++---------- synapse/http/federation/well_known_resolver.py | 25 +++++++++++++++-- synapse/http/matrixfederationclient.py | 13 +++++++-- .../federation/test_matrix_federation_agent.py | 27 ++++++++++++++++++ 5 files changed, 80 insertions(+), 18 deletions(-) create mode 100644 changelog.d/8950.misc (limited to 'tests') diff --git a/changelog.d/8950.misc b/changelog.d/8950.misc new file mode 100644 index 0000000000..42e0335afc --- /dev/null +++ b/changelog.d/8950.misc @@ -0,0 +1 @@ +Add a maximum size of 50 kilobytes to .well-known lookups. 
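(For illustration, not part of the patch: a sketch of the calling pattern the
new helper supports, as used by `WellKnownResolver` below. `read_capped_body`
is a hypothetical wrapper, and Synapse's logcontext handling via
`make_deferred_yieldable` is omitted for brevity.)

    from io import BytesIO
    from typing import Optional

    from twisted.web.iweb import IResponse

    from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size

    async def read_capped_body(
        response: IResponse, max_size: int = 50 * 1024
    ) -> Optional[bytes]:
        """Read a response body into memory, giving up beyond max_size bytes.

        read_body_with_max_size errbacks with BodyExceededMaxSize once the
        limit is crossed (and drops the connection), so an oversized
        .well-known file costs at most about max_size bytes of download.
        """
        stream = BytesIO()
        try:
            await read_body_with_max_size(response, stream, max_size)
        except BodyExceededMaxSize:
            return None  # the resolver treats this as a temporary lookup failure
        return stream.getvalue()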
diff --git a/synapse/http/client.py b/synapse/http/client.py index df7730078f..29f40ddf5f 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -720,11 +720,14 @@ class SimpleHttpClient: try: length = await make_deferred_yieldable( - readBodyToFile(response, output_stream, max_size) + read_body_with_max_size(response, output_stream, max_size) + ) + except BodyExceededMaxSize: + raise SynapseError( + 502, + "Requested file is too large > %r bytes" % (max_size,), + Codes.TOO_LARGE, + ) - except SynapseError: - # This can happen e.g. because the body is too large. - raise except Exception as e: raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e @@ -748,7 +751,11 @@ def _timeout_to_request_timed_out_error(f: Failure): return f -class _ReadBodyToFileProtocol(protocol.Protocol): +class BodyExceededMaxSize(Exception): + """The maximum allowed size of the HTTP body was exceeded.""" + + +class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): def __init__( self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] ): @@ -761,13 +768,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback( - SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - ) - ) + self.deferred.errback(BodyExceededMaxSize()) self.deferred = defer.Deferred() self.transport.loseConnection() @@ -782,12 +783,15 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.deferred.errback(reason) -def readBodyToFile( +def read_body_with_max_size( response: IResponse, stream: BinaryIO, max_size: Optional[int] ) -> defer.Deferred: """ Read an HTTP response body to a file-object. Optionally enforcing a maximum file size. + If the maximum file size is reached, the returned Deferred will resolve to a + Failure with a BodyExceededMaxSize exception. + Args: response: The HTTP response to read from. stream: The file-object to write to. @@ -798,7 +802,7 @@ def readBodyToFile( """ d = defer.Deferred() - response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size)) return d diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 5e08ef1664..b3b6dbcab0 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -15,17 +15,19 @@ import logging import random import time +from io import BytesIO from typing import Callable, Dict, Optional, Tuple import attr from twisted.internet import defer from twisted.internet.interfaces import IReactorTime -from twisted.web.client import RedirectAgent, readBody +from twisted.web.client import RedirectAgent from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IResponse +from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock, json_decoder from synapse.util.caches.ttlcache import TTLCache @@ -53,6 +55,9 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 # lower bound for .well-known cache period WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 +# The maximum size (in bytes) to allow a well-known file to be. +WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB + # Attempt to refetch a cached well-known N% of the TTL before it expires. # e.g. 
if set to 0.2 and we have a cached entry with a TTL of 5mins, then # we'll start trying to refetch 1 minute before it expires. @@ -229,6 +234,9 @@ class WellKnownResolver: server_name: name of the server, from the requested url retry: Whether to retry the request if it fails. + Raises: + _FetchWellKnownFailure if we fail to look up a result + Returns: Returns the response object and body. Response may be a non-200 response. """ @@ -250,7 +258,11 @@ class WellKnownResolver: b"GET", uri, headers=Headers(headers) ) ) - body = await make_deferred_yieldable(readBody(response)) + body_stream = BytesIO() + await make_deferred_yieldable( + read_body_with_max_size(response, body_stream, WELL_KNOWN_MAX_SIZE) + ) + body = body_stream.getvalue() if 500 <= response.code < 600: raise Exception("Non-200 response %s" % (response.code,)) @@ -259,6 +271,15 @@ class WellKnownResolver: except defer.CancelledError: # Bail if we've been cancelled raise + except BodyExceededMaxSize: + # If the well-known file was too large, do not keep attempting + # to download it, but consider it a temporary error. + logger.warning( + "Requested .well-known file for %s is too large > %r bytes", + server_name.decode("ascii"), + WELL_KNOWN_MAX_SIZE, + ) + raise _FetchWellKnownFailure(temporary=True) except Exception as e: if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS: logger.info("Error fetching %s: %s", uri_str, e) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c962994727..b261e078c4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -37,16 +37,19 @@ from twisted.web.iweb import IBodyProducer, IResponse import synapse.metrics import synapse.util.retryutils from synapse.api.errors import ( + Codes, FederationDeniedError, HttpResponseException, RequestSendFailed, + SynapseError, ) from synapse.http import QuieterFileBodyProducer from synapse.http.client import ( BlacklistingAgentWrapper, BlacklistingReactorWrapper, + BodyExceededMaxSize, encode_query_args, - readBodyToFile, + read_body_with_max_size, ) from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.logging.context import make_deferred_yieldable @@ -975,9 +978,15 @@ class MatrixFederationHttpClient: headers = dict(response.headers.getAllRawHeaders()) try: - d = readBodyToFile(response, output_stream, max_size) + d = read_body_with_max_size(response, output_stream, max_size) + d.addTimeout(self.default_timeout, self.reactor) length = await make_deferred_yieldable(d) + except BodyExceededMaxSize: + msg = "Requested file is too large > %r bytes" % (max_size,) + logger.warning( + "{%s} [%s] %s", request.txn_id, request.destination, msg, + ) + raise SynapseError(502, msg, Codes.TOO_LARGE) except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 626acdcaa3..4e51839d0f 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -36,6 +36,7 @@ from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.srv_resolver import Server from synapse.http.federation.well_known_resolver import ( + WELL_KNOWN_MAX_SIZE, WellKnownResolver, _cache_period_from_headers, ) @@ -1107,6 +1108,32 @@ class 
MatrixFederationAgentTests(unittest.TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) + def test_well_known_too_large(self): + """A well-known query that returns a result which is too large should be rejected.""" + self.reactor.lookups["testserv"] = "1.2.3.4" + + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) + + # there should be an attempt to connect on port 443 for the .well-known + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + self._handle_well_known_connection( + client_factory, + expected_sni=b"testserv", + response_headers={b"Cache-Control": b"max-age=1000"}, + content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }', + ) + + # The result is successful, but delegation is disabled. + r = self.successResultOf(fetch_d) + self.assertIsNone(r.delegated_server) + def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails. """ -- cgit 1.5.1 From 06006058d7bf6744078109875cd27f47197aeafa Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 17 Dec 2020 11:43:37 +0100 Subject: Make search statement in List Room and User Admin API case-insensitive (#8931) --- changelog.d/8931.feature | 1 + docs/admin_api/user_admin_api.rst | 9 ++- synapse/storage/databases/main/__init__.py | 7 +- synapse/storage/databases/main/room.py | 4 +- tests/rest/admin/test_room.py | 7 ++ tests/rest/admin/test_user.py | 101 ++++++++++++++++++++++- tests/storage/test_main.py | 7 ++ 7 files changed, 125 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8931.feature (limited to 'tests') diff --git a/changelog.d/8931.feature b/changelog.d/8931.feature new file mode 100644 index 0000000000..35c720eb8c --- /dev/null +++ b/changelog.d/8931.feature @@ -0,0 +1 @@ +Make search statement in List Room and List User Admin API case-insensitive. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 1473a3d4e3..e4d6f8203b 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -30,7 +30,12 @@ It returns a JSON body like the following: ], "avatar_url": "", "admin": false, - "deactivated": false + "deactivated": false, + "password_hash": "$2b$12$p9B4GkqYdRTPGD", + "creation_ts": 1560432506, + "appservice_id": null, + "consent_server_notice_sent": null, + "consent_version": null } URL parameters: @@ -139,7 +144,6 @@ A JSON body is returned with the following shape: "users": [ { "name": "", - "password_hash": "", "is_guest": 0, "admin": 0, "user_type": null, @@ -148,7 +152,6 @@ A JSON body is returned with the following shape: "avatar_url": null }, { "name": "", - "password_hash": "", "is_guest": 0, "admin": 1, "user_type": null, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 871fb646a5..701748f93b 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -339,12 +339,13 @@ class DataStore( filters = [] args = [self.hs.config.server_name] + # `name` is already stored in the database in lower case if name: - filters.append("(name LIKE ? OR displayname LIKE ?)") - args.extend(["@%" + name + "%:%", "%" + name + "%"]) + filters.append("(name LIKE ? 
OR LOWER(displayname) LIKE ?)") + args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"]) elif user_id: filters.append("name LIKE ?") - args.extend(["%" + user_id + "%"]) + args.extend(["%" + user_id.lower() + "%"]) if not guests: filters.append("is_guest = 0") diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 6b89db15c9..4650d0689b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -379,14 +379,14 @@ class RoomWorkerStore(SQLBaseStore): # Filter room names by a string where_statement = "" if search_term: - where_statement = "WHERE state.name LIKE ?" + where_statement = "WHERE LOWER(state.name) LIKE ?" # Our postgres db driver converts ? -> %s in SQL strings as that's the # placeholder for postgres. # HOWEVER, if you put a % into your SQL then everything goes wibbly. # To get around this, we're going to surround search_term with %'s # before giving it to the database in python instead - search_term = "%" + search_term + "%" + search_term = "%" + search_term.lower() + "%" # Set ordering if RoomSortOrder(order_by) == RoomSortOrder.SIZE: diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ca20bcad08..014c30287a 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1050,6 +1050,13 @@ class RoomTestCase(unittest.HomeserverTestCase): _search_test(room_id_2, "else") _search_test(room_id_2, "se") + # Test case insensitive + _search_test(room_id_1, "SOMETHING") + _search_test(room_id_1, "THING") + + _search_test(room_id_2, "ELSE") + _search_test(room_id_2, "SE") + _search_test(None, "foo") _search_test(None, "bar") _search_test(None, "", expected_http_code=400) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index df62317e69..4f379a5e55 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -18,6 +18,7 @@ import hmac import json import urllib.parse from binascii import unhexlify +from typing import Optional from mock import Mock @@ -466,8 +467,12 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.register_user("user1", "pass1", admin=False) - self.register_user("user2", "pass2", admin=False) + self.user1 = self.register_user( + "user1", "pass1", admin=False, displayname="Name 1" + ) + self.user2 = self.register_user( + "user2", "pass2", admin=False, displayname="Name 2" + ) def test_no_auth(self): """ @@ -476,7 +481,20 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. 
+ """ + other_user_token = self.login("user1", "pass1") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self): """ @@ -493,6 +511,83 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) + # Check that all fields are available + for u in channel.json_body["users"]: + self.assertIn("name", u) + self.assertIn("is_guest", u) + self.assertIn("admin", u) + self.assertIn("user_type", u) + self.assertIn("deactivated", u) + self.assertIn("displayname", u) + self.assertIn("avatar_url", u) + + def test_search_term(self): + """Test that searching for a users works correctly""" + + def _search_test( + expected_user_id: Optional[str], + search_term: str, + search_field: Optional[str] = "name", + expected_http_code: Optional[int] = 200, + ): + """Search for a user and check that the returned user's id is a match + + Args: + expected_user_id: The user_id expected to be returned by the API. Set + to None to expect zero results for the search + search_term: The term to search for user names with + search_field: Field which is to request: `name` or `user_id` + expected_http_code: The expected http code for the request + """ + url = self.url + "?%s=%s" % (search_field, search_term,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that users were returned + self.assertTrue("users" in channel.json_body) + users = channel.json_body["users"] + + # Check that the expected number of users were returned + expected_user_count = 1 if expected_user_id else 0 + self.assertEqual(len(users), expected_user_count) + self.assertEqual(channel.json_body["total"], expected_user_count) + + if expected_user_id: + # Check that the first returned user id is correct + u = users[0] + self.assertEqual(expected_user_id, u["name"]) + + # Perform search tests + _search_test(self.user1, "er1") + _search_test(self.user1, "me 1") + + _search_test(self.user2, "er2") + _search_test(self.user2, "me 2") + + _search_test(self.user1, "er1", "user_id") + _search_test(self.user2, "er2", "user_id") + + # Test case insensitive + _search_test(self.user1, "ER1") + _search_test(self.user1, "NAME 1") + + _search_test(self.user2, "ER2") + _search_test(self.user2, "NAME 2") + + _search_test(self.user1, "ER1", "user_id") + _search_test(self.user2, "ER2", "user_id") + + _search_test(None, "foo") + _search_test(None, "bar") + + _search_test(None, "foo", "user_id") + _search_test(None, "bar", "user_id") + class UserRestTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 7e7f1286d9..e9e3bca3bf 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -48,3 +48,10 @@ class DataStoreTestCase(unittest.TestCase): self.assertEquals(1, total) self.assertEquals(self.displayname, users.pop()["displayname"]) + + users, total = yield defer.ensureDeferred( + self.store.get_users_paginate(0, 10, name="BC", guests=False) + ) + + self.assertEquals(1, total) + self.assertEquals(self.displayname, users.pop()["displayname"]) -- cgit 1.5.1 From c07022303ef596fe7f42f6eb7001660a62801715 Mon Sep 17 
00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 17 Dec 2020 13:05:39 +0100 Subject: Fix a bug where deactivated users appear in the directory (#8933) Fixes a bug where deactivated users appeared in the directory when their profile information was updated. It is still necessary to change the profile information of deactivated users, for example to remove their displayname or avatar, but such changes should not make them appear in the directory: they are deactivated. Co-authored-by: Erik Johnston --- changelog.d/8933.bugfix | 1 + synapse/handlers/user_directory.py | 8 ++++-- tests/handlers/test_user_directory.py | 40 +++++++++++++++++++++++++++- tests/rest/admin/test_user.py | 50 ++++++++++++++++++++++++++++++++++- 4 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 changelog.d/8933.bugfix (limited to 'tests') diff --git a/changelog.d/8933.bugfix b/changelog.d/8933.bugfix new file mode 100644 index 0000000000..295933d6cd --- /dev/null +++ b/changelog.d/8933.bugfix @@ -0,0 +1 @@ +Fix a bug where deactivated users appeared in the user directory when their profile information was updated. diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 3d80371f06..7c4eeaaa5e 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -113,9 +113,13 @@ class UserDirectoryHandler(StateDeltasHandler): """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. - is_support = await self.store.is_support_user(user_id) + # Support users are for diagnostics and should not appear in the user directory. - if not is_support: + is_support = await self.store.is_support_user(user_id) + # Profile changes of a deactivated user should not make them appear in the user directory. 
+ is_deactivated = await self.store.get_user_deactivated_status(user_id) + + if not (is_support or is_deactivated): await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 1260721dbf..9c886d671a 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -54,6 +54,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) ) + regular_user_id = "@regular:test" + self.get_success( + self.store.register_user(user_id=regular_user_id, password_hash=None) + ) self.get_success( self.handler.handle_local_profile_change(support_user_id, None) @@ -63,13 +67,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): display_name = "display_name" profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) - regular_user_id = "@regular:test" self.get_success( self.handler.handle_local_profile_change(regular_user_id, profile_info) ) profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) self.assertTrue(profile["display_name"] == display_name) + def test_handle_local_profile_change_with_deactivated_user(self): + # create user + r_user_id = "@regular:test" + self.get_success( + self.store.register_user(user_id=r_user_id, password_hash=None) + ) + + # update profile + display_name = "Regular User" + profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) + self.get_success( + self.handler.handle_local_profile_change(r_user_id, profile_info) + ) + + # profile is in directory + profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + self.assertTrue(profile["display_name"] == display_name) + + # deactivate user + self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) + self.get_success(self.handler.handle_user_deactivated(r_user_id)) + + # profile is not in directory + profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + self.assertTrue(profile is None) + + # update profile after deactivation + self.get_success( + self.handler.handle_local_profile_change(r_user_id, profile_info) + ) + + # profile is still not in directory + profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + self.assertTrue(profile is None) + def test_handle_user_deactivated_support_user(self): s_user_id = "@support:test" self.get_success( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4f379a5e55..9d6ef02511 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -603,7 +603,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.other_user = self.register_user("user", "pass") + self.other_user = self.register_user("user", "pass", displayname="User") self.other_user_token = self.login("user", "pass") self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( self.other_user ) @@ -1012,6 +1012,54 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) + def test_change_name_deactivate_user_user_directory(self): + """Test changing the profile information 
of a deactivated user and + check that they do not appear in the user directory + """ + + # is in user directory + profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + self.assertTrue(profile["display_name"] == "User") + + # Deactivate user + body = json.dumps({"deactivated": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + + # is not in user directory + profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + self.assertTrue(profile is None) + + # Set new displayname user + body = json.dumps({"displayname": "Foobar"}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual("Foobar", channel.json_body["displayname"]) + + # is not in user directory + profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + self.assertTrue(profile is None) + def test_reactivate_user(self): """ Test reactivating another user. -- cgit 1.5.1 From 4c33796b20f934a43f4f09a2bac6653c18d72b69 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 17 Dec 2020 12:55:21 +0000 Subject: Correctly handle AS registrations and add test --- synapse/handlers/auth.py | 8 +++++- synapse/handlers/register.py | 7 ++++- synapse/replication/http/login.py | 12 +++++++-- synapse/rest/client/v2_alpha/register.py | 14 +++++++--- tests/test_mau.py | 45 ++++++++++++++++++++++++++++++-- 5 files changed, 77 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c7dc07008a..3b8ac4325b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -709,6 +709,7 @@ class AuthHandler(BaseHandler): device_id: Optional[str], valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, + is_appservice_ghost: bool = False, ) -> str: """ Creates a new access token for the user with the given user ID. @@ -725,6 +726,7 @@ class AuthHandler(BaseHandler): we should always have a device ID) valid_until_ms: when the token is valid until. None for no expiry. + is_appservice_ghost: Whether the user is an application service ghost user Returns: The access token for the user's session. 
Raises: @@ -745,7 +747,11 @@ class AuthHandler(BaseHandler): "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry ) - await self.auth.check_auth_blocking(user_id) + if ( + not is_appservice_ghost + or self.hs.config.appservice.track_appservice_user_ips + ): + await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) await self.store.add_access_token_to_user( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0d85fd0868..039aff1061 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -630,6 +630,7 @@ class RegistrationHandler(BaseHandler): device_id: Optional[str], initial_display_name: Optional[str], is_guest: bool = False, + is_appservice_ghost: bool = False, ) -> Tuple[str, str]: """Register a device for a user and generate an access token. @@ -651,6 +652,7 @@ class RegistrationHandler(BaseHandler): device_id=device_id, initial_display_name=initial_display_name, is_guest=is_guest, + is_appservice_ghost=is_appservice_ghost, ) return r["device_id"], r["access_token"] @@ -672,7 +674,10 @@ class RegistrationHandler(BaseHandler): ) else: access_token = await self._auth_handler.get_access_token_for_user_id( - user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms + user_id, + device_id=registered_device_id, + valid_until_ms=valid_until_ms, + is_appservice_ghost=is_appservice_ghost, ) return (registered_device_id, access_token) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 4c81e2d784..36071feb36 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - async def _serialize_payload(user_id, device_id, initial_display_name, is_guest): + async def _serialize_payload( + user_id, device_id, initial_display_name, is_guest, is_appservice_ghost + ): """ Args: device_id (str|None): Device ID to use, if None a new one is @@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "device_id": device_id, "initial_display_name": initial_display_name, "is_guest": is_guest, + "is_appservice_ghost": is_appservice_ghost, } async def _handle_request(self, request, user_id): @@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] + is_appservice_ghost = content["is_appservice_ghost"] device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest + user_id, + device_id, + initial_display_name, + is_guest, + is_appservice_ghost=is_appservice_ghost, ) return 200, {"device_id": device_id, "access_token": access_token} diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a89ae6ddf9..722d993811 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -655,9 +655,13 @@ class RegisterRestServlet(RestServlet): user_id = await self.registration_handler.appservice_register( username, as_token ) - return await self._create_registration_details(user_id, body) + return await self._create_registration_details( + user_id, body, is_appservice_ghost=True, + ) - async def _create_registration_details(self, user_id, params): + async def 
_create_registration_details( + self, user_id, params, is_appservice_ghost=False + ): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -674,7 +678,11 @@ class RegisterRestServlet(RestServlet): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=False + user_id, + device_id, + initial_display_name, + is_guest=False, + is_appservice_ghost=is_appservice_ghost, ) result.update({"access_token": access_token, "device_id": device_id}) diff --git a/tests/test_mau.py b/tests/test_mau.py index c5ec6396a7..26548b4611 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -19,6 +19,7 @@ import json from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError +from synapse.appservice import ApplicationService from synapse.rest.client.v2_alpha import register, sync from tests import unittest @@ -75,6 +76,44 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_as_ignores_mau(self): + """Test that application services can still create users when the MAU + limit has been reached. + """ + + # Create and sync so that the MAU counts get updated + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + + # check we're testing what we think we are: there should be two active users + self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2) + + # We've created and activated two users, we shouldn't be able to + # register new users + with self.assertRaises(SynapseError) as cm: + self.create_user("kermit3") + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + # Cheekily add an application service that we use to register a new user + # with. 
+ as_token = "foobartoken" + self.store.services_cache.append( + ApplicationService( + token=as_token, + hostname=self.hs.hostname, + id="SomeASID", + sender="@as_sender:test", + namespaces={"users": [{"regex": "@as_*", "exclusive": True}]}, + ) + ) + + self.create_user("as_kermit4", token=as_token) + def test_allowed_after_a_month_mau(self): # Create and sync so that the MAU counts get updated token1 = self.create_user("kermit1") @@ -192,7 +231,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.reactor.advance(100) self.assertEqual(2, self.successResultOf(count)) - def create_user(self, localpart): + def create_user(self, localpart, token=None): request_data = json.dumps( { "username": localpart, @@ -201,7 +240,9 @@ class TestMauLimit(unittest.HomeserverTestCase): } ) - request, channel = self.make_request("POST", "/register", request_data) + request, channel = self.make_request( + "POST", "/register", request_data, access_token=token + ) if channel.code != 200: raise HttpResponseException( -- cgit 1.5.1 From f2783fc201edaa49eafd8be06f8cda16ec1f3d95 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 17 Dec 2020 14:42:30 +0100 Subject: Use the simple dictionary in full text search for the user directory (#8959) * Use the simple dictionary in fts for the user directory * Clarify naming --- changelog.d/8959.bugfix | 1 + synapse/storage/databases/main/user_directory.py | 24 ++++++++++++------------ tests/storage/test_user_directory.py | 23 +++++++++++++++++++++++ 3 files changed, 36 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8959.bugfix (limited to 'tests') diff --git a/changelog.d/8959.bugfix b/changelog.d/8959.bugfix new file mode 100644 index 0000000000..772818bae9 --- /dev/null +++ b/changelog.d/8959.bugfix @@ -0,0 +1 @@ +Fix a bug causing common English words to not be considered for a user directory search. 
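The hunks below swap the 'english' full-text-search dictionary for 'simple' because 'english' stems tokens and silently discards stop words, so a query for a common word (or a display-name prefix that happens to be one, such as "be") can match nothing. A quick way to observe the difference against any Postgres instance, sketched here with psycopg2 (the connection string is illustrative):

    import psycopg2

    conn = psycopg2.connect("dbname=synapse")  # hypothetical DSN
    with conn.cursor() as cur:
        cur.execute("SELECT to_tsvector('english', 'be')")
        print(cur.fetchone())  # ('',) -- "be" is a stop word, so nothing is indexed

        cur.execute("SELECT to_tsvector('simple', 'be')")
        print(cur.fetchone())  # ("'be':1",) -- 'simple' keeps every token as-is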
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index fc8caf46a0..ef11f1c3b3 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -396,9 +396,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, - setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + setweight(to_tsvector('simple', ?), 'A') + || setweight(to_tsvector('simple', ?), 'D') + || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector """ txn.execute( @@ -418,9 +418,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, - setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + setweight(to_tsvector('simple', ?), 'A') + || setweight(to_tsvector('simple', ?), 'D') + || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') ) """ txn.execute( @@ -435,9 +435,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): elif new_entry is False: sql = """ UPDATE user_directory_search - SET vector = setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + SET vector = setweight(to_tsvector('simple', ?), 'A') + || setweight(to_tsvector('simple', ?), 'D') + || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') WHERE user_id = ? """ txn.execute( @@ -764,7 +764,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): INNER JOIN user_directory AS d USING (user_id) WHERE %s - AND vector @@ to_tsquery('english', ?) + AND vector @@ to_tsquery('simple', ?) ORDER BY (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) @@ -773,13 +773,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): 3 * ts_rank_cd( '{0.1, 0.1, 0.9, 1.0}', vector, - to_tsquery('english', ?), + to_tsquery('simple', ?), 8 ) + ts_rank_cd( '{0.1, 0.1, 0.9, 1.0}', vector, - to_tsquery('english', ?), + to_tsquery('simple', ?), 8 ) ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 738e912468..a6f63f4aaf 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -21,6 +21,8 @@ from tests.utils import setup_test_homeserver ALICE = "@alice:a" BOB = "@bob:b" BOBBY = "@bobby:a" +# The localpart isn't 'Bela' on purpose so we can test looking up display names. 
+BELA = "@somenickname:a" class UserDirectoryStoreTestCase(unittest.TestCase): @@ -40,6 +42,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase): yield defer.ensureDeferred( self.store.update_profile_in_user_dir(BOBBY, "bobby", None) ) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(BELA, "Bela", None) + ) yield defer.ensureDeferred( self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) ) @@ -72,3 +77,21 @@ class UserDirectoryStoreTestCase(unittest.TestCase): ) finally: self.hs.config.user_directory_search_all_users = False + + @defer.inlineCallbacks + def test_search_user_dir_stop_words(self): + """Tests that a user can look up another user by searching for the start if its + display name even if that name happens to be a common English word that would + usually be ignored in full text searches. + """ + self.hs.config.user_directory_search_all_users = True + try: + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, + ) + finally: + self.hs.config.user_directory_search_all_users = False -- cgit 1.5.1 From c9c1c9d82f190abc2f1254f75fe42bf29eff08e1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 17 Dec 2020 15:46:40 +0000 Subject: Fix `UsersListTestCase` (#8964) --- changelog.d/8964.bugfix | 1 + tests/rest/admin/test_user.py | 10 ++++------ 2 files changed, 5 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8964.bugfix (limited to 'tests') diff --git a/changelog.d/8964.bugfix b/changelog.d/8964.bugfix new file mode 100644 index 0000000000..295933d6cd --- /dev/null +++ b/changelog.d/8964.bugfix @@ -0,0 +1 @@ +Fix a bug where deactivated users appeared in the user directory when their profile information was updated. 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9d6ef02511..9b2e4765f6 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -489,9 +489,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user1", "pass1") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url, access_token=other_user_token) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -540,7 +538,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): expected_http_code: The expected http code for the request """ url = self.url + "?%s=%s" % (search_field, search_term,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) @@ -1026,7 +1024,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Deactivate user body = json.dumps({"deactivated": True}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -1044,7 +1042,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set new displayname user body = json.dumps({"displayname": "Foobar"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, -- cgit 1.5.1 From 14eab1b4d2cf837c6b0925da7194cc5940e9401c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 17 Dec 2020 16:14:13 +0000 Subject: Update tests/test_mau.py Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- tests/test_mau.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_mau.py b/tests/test_mau.py index 26548b4611..586294446b 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -78,7 +78,8 @@ class TestMauLimit(unittest.HomeserverTestCase): def test_as_ignores_mau(self): """Test that application services can still create users when the MAU - limit has been reached. + limit has been reached. This only works when application service + user ip tracking is disabled. """ # Create and sync so that the MAU counts get updated -- cgit 1.5.1 From 70586aa63eaf129505324976fb092cb3ad327590 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Dec 2020 09:49:18 +0000 Subject: Try and drop stale extremities. (#8929) If we see stale extremities while persisting events, and notice that they don't change the result of state resolution, we drop them. --- changelog.d/8929.misc | 1 + synapse/api/constants.py | 2 + synapse/handlers/message.py | 2 +- synapse/storage/persist_events.py | 200 +++++++++++++++++++++-- synapse/visibility.py | 2 +- tests/storage/test_events.py | 334 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 523 insertions(+), 18 deletions(-) create mode 100644 changelog.d/8929.misc create mode 100644 tests/storage/test_events.py (limited to 'tests') diff --git a/changelog.d/8929.misc b/changelog.d/8929.misc new file mode 100644 index 0000000000..157018b6a6 --- /dev/null +++ b/changelog.d/8929.misc @@ -0,0 +1 @@ +Automatically drop stale forward-extremities under some specific conditions. 
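Before the diffs, a condensed sketch of the pruning heuristic they implement. The thresholds mirror the ones in the new `_prune_extremities` code (roughly a day old and 100 events behind the current depth, or only 20 events behind when the sending server shows up again); the helper and its argument types here are illustrative, not the actual implementation:

    ONE_DAY_MS = 24 * 60 * 60 * 1000


    def may_drop_extremity(ev, current_depth, now_ms, new_senders, my_domain) -> bool:
        """Return True if `ev` looks like a "stuck" extremity that is safe to drop."""
        sender_domain = ev.sender.split(":", 1)[1]
        if sender_domain == my_domain:
            # Our own events are kept (the real code carves out an exception
            # for dummy events that only reference remote events).
            return False
        if now_ms - ev.origin_server_ts > ONE_DAY_MS and ev.depth < current_depth - 100:
            return True  # old, and far behind the current depth
        if sender_domain in new_senders and ev.depth < current_depth - 20:
            return True  # that server is active again, so be less conservative
        return False

Note that the real change is all-or-nothing per room: it only prunes when every stale extremity qualifies, since dropping a subset could change the outcome of state resolution.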
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 1932df83b4..565a8cd76a 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -95,6 +95,8 @@ class EventTypes: Presence = "m.presence" + Dummy = "org.matrix.dummy_event" + class RejectedReason: AUTH_ERROR = "auth_error" diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 2b8aa9443d..9dfeab09cd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1261,7 +1261,7 @@ class EventCreationHandler: event, context = await self.create_event( requester, { - "type": "org.matrix.dummy_event", + "type": EventTypes.Dummy, "content": {}, "room_id": room_id, "sender": user_id, diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 70e636b0ba..61fc49c69c 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -31,7 +31,14 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState -from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.types import ( + Collection, + PersistedEventPosition, + RoomStreamToken, + StateMap, + get_domain_from_id, +) from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -68,6 +75,21 @@ stale_forward_extremities_counter = Histogram( buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) +state_resolutions_during_persistence = Counter( + "synapse_storage_events_state_resolutions_during_persistence", + "Number of times we had to do state res to calculate new current state", +) + +potential_times_prune_extremities = Counter( + "synapse_storage_events_potential_times_prune_extremities", + "Number of times we might be able to prune extremities", +) + +times_pruned_extremities = Counter( + "synapse_storage_events_times_pruned_extremities", + "Number of times we were actually be able to prune extremities", +) + class _EventPeristenceQueue: """Queues up events so that they can be persisted in bulk with only one @@ -454,7 +476,15 @@ class EventsPersistenceStorage: latest_event_ids, new_latest_event_ids, ) - current_state, delta_ids = res + current_state, delta_ids, new_latest_event_ids = res + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" 
+ + new_forward_extremeties[room_id] = new_latest_event_ids # If either are not None then there has been a change, # and we need to work out the delta (or use that @@ -573,29 +603,35 @@ class EventsPersistenceStorage: self, room_id: str, events_context: List[Tuple[EventBase, EventContext]], - old_latest_event_ids: Iterable[str], - new_latest_event_ids: Iterable[str], - ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: + old_latest_event_ids: Set[str], + new_latest_event_ids: Set[str], + ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: """Calculate the current state dict after adding some new events to a room Args: - room_id (str): + room_id: room to which the events are being added. Used for logging etc - events_context (list[(EventBase, EventContext)]): + events_context: events and contexts which are being added to the room - old_latest_event_ids (iterable[str]): + old_latest_event_ids: the old forward extremities for the room. - new_latest_event_ids (iterable[str]): + new_latest_event_ids : the new forward extremities for the room. Returns: - Returns a tuple of two state maps, the first being the full new current - state and the second being the delta to the existing current state. - If both are None then there has been no change. + Returns a tuple of two state maps and a set of new forward + extremities. + + The first state map is the full new current state and the second + is the delta to the existing current state. If both are None then + there has been no change. + + The function may prune some old entries from the set of new + forward extremities if it's safe to do so. If there has been a change then we only return the delta if its already been calculated. Conversely if we do know the delta then @@ -672,7 +708,7 @@ class EventsPersistenceStorage: # If they old and new groups are the same then we don't need to do # anything. if old_state_groups == new_state_groups: - return None, None + return None, None, new_latest_event_ids if len(new_state_groups) == 1 and len(old_state_groups) == 1: # If we're going from one state group to another, lets check if @@ -689,7 +725,7 @@ class EventsPersistenceStorage: # the current state in memory then lets also return that, # but it doesn't matter if we don't. new_state = state_groups_map.get(new_state_group) - return new_state, delta_ids + return new_state, delta_ids, new_latest_event_ids # Now that we have calculated new_state_groups we need to get # their state IDs so we can resolve to a single state set. @@ -701,7 +737,7 @@ class EventsPersistenceStorage: if len(new_state_groups) == 1: # If there is only one state group, then we know what the current # state is. - return state_groups_map[new_state_groups.pop()], None + return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids # Ok, we need to defer to the state handler to resolve our state sets. @@ -734,7 +770,139 @@ class EventsPersistenceStorage: state_res_store=StateResolutionStore(self.main_store), ) - return res.state, None + state_resolutions_during_persistence.inc() + + # If the returned state matches the state group of one of the new + # forward extremities then we check if we are able to prune some state + # extremities. 
+ if res.state_group and res.state_group in new_state_groups: + new_latest_event_ids = await self._prune_extremities( + room_id, + new_latest_event_ids, + res.state_group, + event_id_to_state_group, + events_context, + ) + + return res.state, None, new_latest_event_ids + + async def _prune_extremities( + self, + room_id: str, + new_latest_event_ids: Set[str], + resolved_state_group: int, + event_id_to_state_group: Dict[str, int], + events_context: List[Tuple[EventBase, EventContext]], + ) -> Set[str]: + """See if we can prune any of the extremities after calculating the + resolved state. + """ + potential_times_prune_extremities.inc() + + # We keep all the extremities that have the same state group, and + # see if we can drop the others. + new_new_extrems = { + e + for e in new_latest_event_ids + if event_id_to_state_group[e] == resolved_state_group + } + + dropped_extrems = set(new_latest_event_ids) - new_new_extrems + + logger.debug("Might drop extremities: %s", dropped_extrems) + + # We only drop events from the extremities list if: + # 1. we're not currently persisting them; + # 2. they're not our own events (or are dummy events); and + # 3. they're either: + # 1. over N hours old and more than N events ago (we use depth to + # calculate); or + # 2. we are persisting an event from the same domain and more than + # M events ago. + # + # The idea is that we don't want to drop events that are "legitimate" + # extremities (that we would want to include as prev events), only + # "stuck" extremities that are e.g. due to a gap in the graph. + # + # Note that we either drop all of them or none of them. If we only drop + # some of the events we don't know if state res would come to the same + # conclusion. + + for ev, _ in events_context: + if ev.event_id in dropped_extrems: + logger.debug( + "Not dropping extremities: %s is being persisted", ev.event_id + ) + return new_latest_event_ids + + dropped_events = await self.main_store.get_events( + dropped_extrems, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + + new_senders = {get_domain_from_id(e.sender) for e, _ in events_context} + + one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 + current_depth = max(e.depth for e, _ in events_context) + for event in dropped_events.values(): + # If the event is a local dummy event then we should check it + # doesn't reference any local events, as we want to reference those + # if we send any new events. + # + # Note we do this recursively to handle the case where a dummy event + # references a dummy event that only references remote events. + # + # Ideally we'd figure out a way of still being able to drop old + # dummy events that reference local events, but this is good enough + # as a first cut. 
+ events_to_check = [event] + while events_to_check: + new_events = set() + for event_to_check in events_to_check: + if self.is_mine_id(event_to_check.sender): + if event_to_check.type != EventTypes.Dummy: + logger.debug("Not dropping own event") + return new_latest_event_ids + new_events.update(event_to_check.prev_event_ids()) + + prev_events = await self.main_store.get_events( + new_events, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + events_to_check = prev_events.values() + + if ( + event.origin_server_ts < one_day_ago + and event.depth < current_depth - 100 + ): + continue + + # We can be less conservative about dropping extremities from the + # same domain, though we do want to wait a little bit (otherwise + # we'll immediately remove all extremities from a given server). + if ( + get_domain_from_id(event.sender) in new_senders + and event.depth < current_depth - 20 + ): + continue + + logger.debug( + "Not dropping as too new and not in new_senders: %s", new_senders, + ) + + return new_latest_event_ids + + times_pruned_extremities.inc() + + logger.info( + "Pruning forward extremities in room %s: from %s -> %s", + room_id, + new_latest_event_ids, + new_new_extrems, + ) + return new_new_extrems async def _calculate_state_delta( self, room_id: str, current_state: StateMap[str] diff --git a/synapse/visibility.py b/synapse/visibility.py index f2836ba9f0..ec50e7e977 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -125,7 +125,7 @@ async def filter_events_for_client( # see events in the room at that point in the DAG, and that shouldn't be decided # on those checks. if filter_send_to_client: - if event.type == "org.matrix.dummy_event": + if event.type == EventTypes.Dummy: return None if not event.is_state() and event.sender in ignore_list: diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py new file mode 100644 index 0000000000..71210ce606 --- /dev/null +++ b/tests/storage/test_events.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions +from synapse.federation.federation_base import event_from_pdu_json +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.unittest import HomeserverTestCase + + +class ExtremPruneTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.state = self.hs.get_state_handler() + self.persistence = self.hs.get_storage().persistence + self.store = self.hs.get_datastore() + + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=self.token + ) + + body = self.helper.send(self.room_id, body="Test", tok=self.token) + local_message_event_id = body["event_id"] + + # Fudge a remote event and persist it. This will be the extremity before + # the gap. + self.remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Message, + "state_key": "@user:other", + "content": {}, + "room_id": self.room_id, + "sender": "@user:other", + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + self.persist_event(self.remote_event_1) + + # Check that the current extremities is the remote event. + self.assert_extremities([self.remote_event_1.event_id]) + + def persist_event(self, event, state=None): + """Persist the event, with optional state + """ + context = self.get_success( + self.state.compute_event_context(event, old_state=state) + ) + self.get_success(self.persistence.persist_event(event, context)) + + def assert_extremities(self, expected_extremities): + """Assert the current extremities for the room + """ + extremities = self.get_success( + self.store.get_prev_events_for_room(self.room_id) + ) + self.assertCountEqual(extremities, expected_extremities) + + def test_prune_gap(self): + """Test that we drop extremities after a gap when we see an event from + the same domain. + """ + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other", + "depth": 50, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id]) + + def test_do_not_prune_gap_if_state_different(self): + """Test that we don't prune extremities after a gap if the resolved + state is different. + """ + + # Fudge a second event which points to an event we don't have. 
+ remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Message, + "state_key": "@user:other", + "content": {}, + "room_id": self.room_id, + "sender": "@user:other", + "depth": 10, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + # Now we persist it with state with a dropped history visibility + # setting. The state resolution across the old and new event will then + # include it, and so the resolved state won't match the new state. + state_before_gap = dict( + self.get_success(self.state.get_current_state(self.room_id)) + ) + state_before_gap.pop(("m.room.history_visibility", "")) + + context = self.get_success( + self.state.compute_event_context( + remote_event_2, old_state=state_before_gap.values() + ) + ) + + self.get_success(self.persistence.persist_event(remote_event_2, context)) + + # Check that we haven't dropped the old extremity. + self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) + + def test_prune_gap_if_old(self): + """Test that we drop extremities after a gap when the previous extremity + is "old" + """ + + # Advance the clock for many days to make the old extremity "old". We + # also set the depth to "lots". + self.reactor.advance(7 * 24 * 60 * 60) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id]) + + def test_do_not_prune_gap_if_other_server(self): + """Test that we do not drop extremities after a gap when we see an event + from a different domain. + """ + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) + + def test_prune_gap_if_dummy_remote(self): + """Test that we drop extremities after a gap when the previous extremity + is a local dummy event and only points to remote events. 
+ """ + + body = self.helper.send_event( + self.room_id, type=EventTypes.Dummy, content={}, tok=self.token + ) + local_message_event_id = body["event_id"] + self.assert_extremities([local_message_event_id]) + + # Advance the clock for many days to make the old extremity "old". We + # also set the depth to "lots". + self.reactor.advance(7 * 24 * 60 * 60) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id]) + + def test_prune_gap_if_dummy_local(self): + """Test that we don't drop extremities after a gap when the previous + extremity is a local dummy event and points to local events. + """ + + body = self.helper.send(self.room_id, body="Test", tok=self.token) + + body = self.helper.send_event( + self.room_id, type=EventTypes.Dummy, content={}, tok=self.token + ) + local_message_event_id = body["event_id"] + self.assert_extremities([local_message_event_id]) + + # Advance the clock for many days to make the old extremity "old". We + # also set the depth to "lots". + self.reactor.advance(7 * 24 * 60 * 60) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id, local_message_event_id]) + + def test_do_not_prune_gap_if_not_dummy(self): + """Test that we do not drop extremities after a gap when the previous extremity + is not a dummy event. + """ + + body = self.helper.send(self.room_id, body="test", tok=self.token) + local_message_event_id = body["event_id"] + self.assert_extremities([local_message_event_id]) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). 
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10000,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check that we keep the old local extremity alongside the new remote event.
+        self.assert_extremities([local_message_event_id, remote_event_2.event_id])
--
cgit 1.5.1


From 5d4c330ed979b0d60efe5f80fd76de8f162263a1 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Fri, 18 Dec 2020 07:33:57 -0500
Subject: Allow re-using a UI auth validation for a period of time (#8970)

---
 changelog.d/8970.feature                           |   1 +
 docs/sample_config.yaml                            |  15 +++
 synapse/config/_base.pyi                           |   4 +-
 synapse/config/auth.py                             | 110 +++++++++++++++++++++
 synapse/config/homeserver.py                       |   4 +-
 synapse/config/password.py                         |  90 -----------------
 synapse/handlers/auth.py                           |  32 ++++--
 synapse/rest/client/v2_alpha/account.py            |  10 +-
 synapse/storage/databases/main/registration.py     |  38 +++++++
 .../delta/58/26access_token_last_validated.sql     |  18 ++++
 tests/rest/client/v2_alpha/test_auth.py            |  94 ++++++++++++------
 11 files changed, 280 insertions(+), 136 deletions(-)
 create mode 100644 changelog.d/8970.feature
 create mode 100644 synapse/config/auth.py
 delete mode 100644 synapse/config/password.py
 create mode 100644 synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql

(limited to 'tests')

diff --git a/changelog.d/8970.feature b/changelog.d/8970.feature
new file mode 100644
index 0000000000..6d5b3303a6
--- /dev/null
+++ b/changelog.d/8970.feature
@@ -0,0 +1 @@
+Allow re-using a user-interactive authentication session for a period of time.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 75a01094d5..549c581a97 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -2068,6 +2068,21 @@ password_config:
    #
    #require_uppercase: true

+ui_auth:
+    # The number of milliseconds to allow a user-interactive authentication
+    # session to be active.
+    #
+    # This defaults to 0, meaning the user is queried for their credentials
+    # before every action, but this can be overridden to allow a single
+    # validation to be re-used. This weakens the protections afforded by
+    # the user-interactive authentication process, by allowing for multiple
+    # (and potentially different) operations to use the same validation session.
+    #
+    # Uncomment below to allow for credential validation to last for 15
+    # seconds.
+    #
+    #session_timeout: 15000
+

 # Configuration for sending emails from Synapse.
 #
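As a reading aid, the re-use window configured above reduces to a single timestamp
comparison; a minimal sketch, with `clock`, `last_validated_ms` and
`session_timeout_ms` as assumed names (the real check lives in
synapse/handlers/auth.py later in this patch):

    # Sketch only: mirrors the check that `session_timeout` enables.
    def can_skip_ui_auth(clock, last_validated_ms: int, session_timeout_ms: int) -> bool:
        # A timeout of 0 (the default) never skips user-interactive auth.
        if session_timeout_ms <= 0:
            return False
        # Otherwise, skip it if the access token was validated recently enough.
        return clock.time_msec() - last_validated_ms < session_timeout_ms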
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index ed26e2fb60..29aa064e57 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -3,6 +3,7 @@
 from typing import Any, Iterable, List, Optional
 from synapse.config import (
     api,
     appservice,
+    auth,
     captcha,
     cas,
     consent_config,
@@ -14,7 +15,6 @@ from synapse.config import (
     logger,
     metrics,
     oidc_config,
-    password,
     password_auth_providers,
     push,
     ratelimiting,
@@ -65,7 +65,7 @@ class RootConfig:
     sso: sso.SSOConfig
     oidc: oidc_config.OIDCConfig
     jwt: jwt_config.JWTConfig
-    password: password.PasswordConfig
+    auth: auth.AuthConfig
     email: emailconfig.EmailConfig
     worker: workers.WorkerConfig
     authproviders: password_auth_providers.PasswordAuthProviderConfig
diff --git a/synapse/config/auth.py b/synapse/config/auth.py
new file mode 100644
index 0000000000..2b3e2ce87b
--- /dev/null
+++ b/synapse/config/auth.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class AuthConfig(Config):
+    """Password and login configuration
+    """
+
+    section = "auth"
+
+    def read_config(self, config, **kwargs):
+        password_config = config.get("password_config", {})
+        if password_config is None:
+            password_config = {}
+
+        self.password_enabled = password_config.get("enabled", True)
+        self.password_localdb_enabled = password_config.get("localdb_enabled", True)
+        self.password_pepper = password_config.get("pepper", "")
+
+        # Password policy
+        self.password_policy = password_config.get("policy") or {}
+        self.password_policy_enabled = self.password_policy.get("enabled", False)
+
+        # User-interactive authentication
+        ui_auth = config.get("ui_auth") or {}
+        self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0)
+
+    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+        return """\
+        password_config:
+           # Uncomment to disable password login
+           #
+           #enabled: false
+
+           # Uncomment to disable authentication against the local password
+           # database. This is ignored if `enabled` is false, and is only useful
+           # if you have other password_providers.
+           #
+           #localdb_enabled: false
+
+           # Uncomment and change to a secret random string for extra security.
+           # DO NOT CHANGE THIS AFTER INITIAL SETUP!
+           #
+           #pepper: "EVEN_MORE_SECRET"
+
+           # Define and enforce a password policy. Each parameter is optional.
+           # This is an implementation of MSC2000.
+           #
+           policy:
+              # Whether to enforce the password policy.
+              # Defaults to 'false'.
+              #
+              #enabled: true
+
+              # Minimum accepted length for a password.
+              # Defaults to 0.
+              #
+              #minimum_length: 15
+
+              # Whether a password must contain at least one digit.
+              # Defaults to 'false'.
+              #
+              #require_digit: true
+
+              # Whether a password must contain at least one symbol.
+              # A symbol is any character that's not a number or a letter.
+              # Defaults to 'false'.
+              #
+              #require_symbol: true
+
+              # Whether a password must contain at least one lowercase letter.
+              # Defaults to 'false'.
+              #
+              #require_lowercase: true
+
+              # Whether a password must contain at least one uppercase letter.
+              # Defaults to 'false'.
+              #
+              #require_uppercase: true
+
+        ui_auth:
+            # The number of milliseconds to allow a user-interactive authentication
+            # session to be active.
+            #
+            # This defaults to 0, meaning the user is queried for their credentials
+            # before every action, but this can be overridden to allow a single
+            # validation to be re-used. This weakens the protections afforded by
+            # the user-interactive authentication process, by allowing for multiple
+            # (and potentially different) operations to use the same validation session.
+            #
+            # Uncomment below to allow for credential validation to last for 15
+            # seconds.
+            #
+            #session_timeout: 15000
+        """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index be65554524..4bd2b3587b 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -17,6 +17,7 @@
 from ._base import RootConfig
 from .api import ApiConfig
 from .appservice import AppServiceConfig
+from .auth import AuthConfig
 from .cache import CacheConfig
 from .captcha import CaptchaConfig
 from .cas import CasConfig
@@ -30,7 +31,6 @@ from .key import KeyConfig
 from .logger import LoggingConfig
 from .metrics import MetricsConfig
 from .oidc_config import OIDCConfig
-from .password import PasswordConfig
 from .password_auth_providers import PasswordAuthProviderConfig
 from .push import PushConfig
 from .ratelimiting import RatelimitConfig
@@ -76,7 +76,7 @@ class HomeServerConfig(RootConfig):
         CasConfig,
         SSOConfig,
         JWTConfig,
-        PasswordConfig,
+        AuthConfig,
         EmailConfig,
         PasswordAuthProviderConfig,
         PushConfig,
diff --git a/synapse/config/password.py b/synapse/config/password.py
deleted file mode 100644
index 9c0ea8c30a..0000000000
--- a/synapse/config/password.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import Config
-
-
-class PasswordConfig(Config):
-    """Password login configuration
-    """
-
-    section = "password"
-
-    def read_config(self, config, **kwargs):
-        password_config = config.get("password_config", {})
-        if password_config is None:
-            password_config = {}
-
-        self.password_enabled = password_config.get("enabled", True)
-        self.password_localdb_enabled = password_config.get("localdb_enabled", True)
-        self.password_pepper = password_config.get("pepper", "")
-
-        # Password policy
-        self.password_policy = password_config.get("policy") or {}
-        self.password_policy_enabled = self.password_policy.get("enabled", False)
-
-    def generate_config_section(self, config_dir_path, server_name, **kwargs):
-        return """\
-        password_config:
-           # Uncomment to disable password login
-           #
-           #enabled: false
-
-           # Uncomment to disable authentication against the local password
-           # database. This is ignored if `enabled` is false, and is only useful
-           # if you have other password_providers.
-           #
-           #localdb_enabled: false
-
-           # Uncomment and change to a secret random string for extra security.
-           # DO NOT CHANGE THIS AFTER INITIAL SETUP!
-           #
-           #pepper: "EVEN_MORE_SECRET"
-
-           # Define and enforce a password policy. Each parameter is optional.
-           # This is an implementation of MSC2000.
-           #
-           policy:
-              # Whether to enforce the password policy.
-              # Defaults to 'false'.
-              #
-              #enabled: true
-
-              # Minimum accepted length for a password.
-              # Defaults to 0.
-              #
-              #minimum_length: 15
-
-              # Whether a password must contain at least one digit.
-              # Defaults to 'false'.
-              #
-              #require_digit: true
-
-              # Whether a password must contain at least one symbol.
-              # A symbol is any character that's not a number or a letter.
-              # Defaults to 'false'.
-              #
-              #require_symbol: true
-
-              # Whether a password must contain at least one lowercase letter.
-              # Defaults to 'false'.
-              #
-              #require_lowercase: true
-
-              # Whether a password must contain at least one lowercase letter.
-              # Defaults to 'false'.
-              #
-              #require_uppercase: true
-        """
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 57ff461f92..f4434673dc 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -226,6 +226,9 @@ class AuthHandler(BaseHandler):
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
         )

+        # The number of milliseconds to keep a UI auth session active.
+        self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
+
         # Ratelimitier for failed /login attempts
         self._failed_login_attempts_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -283,7 +286,7 @@ class AuthHandler(BaseHandler):
         request_body: Dict[str, Any],
         clientip: str,
         description: str,
-    ) -> Tuple[dict, str]:
+    ) -> Tuple[dict, Optional[str]]:
         """
         Checks that the user is who they claim to be, via a UI auth.

@@ -310,7 +313,8 @@ class AuthHandler(BaseHandler):
             have been given only in a previous call).

             'session_id' is the ID of this session, either passed in by the
-            client or assigned by this call
+            client or assigned by this call. This is None if UI auth was
+            skipped (by re-using a previous validation).

         Raises:
             InteractiveAuthIncompleteError if the client has not yet completed
@@ -324,6 +328,16 @@ class AuthHandler(BaseHandler):

         """

+        if self._ui_auth_session_timeout:
+            last_validated = await self.store.get_access_token_last_validated(
+                requester.access_token_id
+            )
+            if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout:
+                # Return the input parameters, minus the auth key, which matches
+                # the logic in check_ui_auth.
+                request_body.pop("auth", None)
+                return request_body, None
+
         user_id = requester.user.to_string()

         # Check if we should be ratelimited due to too many previous failed attempts
@@ -359,6 +373,9 @@ class AuthHandler(BaseHandler):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Invalid auth")

+        # Note that the access token has been validated.
+        await self.store.update_access_token_last_validated(requester.access_token_id)
+
         return params, session_id

     async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
@@ -452,13 +469,10 @@ class AuthHandler(BaseHandler):
             all the stages in any of the permitted flows.
""" - authdict = None sid = None # type: Optional[str] - if clientdict and "auth" in clientdict: - authdict = clientdict["auth"] - del clientdict["auth"] - if "session" in authdict: - sid = authdict["session"] + authdict = clientdict.pop("auth", {}) + if "session" in authdict: + sid = authdict["session"] # Convert the URI and method to strings. uri = request.uri.decode("utf-8") @@ -563,6 +577,8 @@ class AuthHandler(BaseHandler): creds = await self.store.get_completed_ui_auth_stages(session.session_id) for f in flows: + # If all the required credentials have been supplied, the user has + # successfully completed the UI auth process! if len(set(f) - set(creds)) == 0: # it's very useful to know what args are stored, but this can # include the password in the case of registering, so only log diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index eebee44a44..d837bde1d6 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -254,14 +254,18 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - # If we have a password in this request, prefer it. Otherwise, there - # must be a password hash from an earlier request. + # If we have a password in this request, prefer it. Otherwise, use the + # password hash from an earlier request. if new_password: password_hash = await self.auth_handler.hash(new_password) - else: + elif session_id is not None: password_hash = await self.auth_handler.get_session_data( session_id, "password_hash", None ) + else: + # UI validation was skipped, but the request did not include a new + # password. + password_hash = None if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index ff96c34c2e..8d05288ed4 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -943,6 +943,42 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="del_user_pending_deactivation", ) + async def get_access_token_last_validated(self, token_id: int) -> int: + """Retrieves the time (in milliseconds) of the last validation of an access token. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if the access token was not found. + + Returns: + The last validation time. + """ + result = await self.db_pool.simple_select_one_onecol( + "access_tokens", {"id": token_id}, "last_validated" + ) + + # If this token has not been validated (since starting to track this), + # return 0 instead of None. + return result or 0 + + async def update_access_token_last_validated(self, token_id: int) -> None: + """Updates the last time an access token was validated. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if there was a problem updating this. 
+ """ + now = self._clock.time_msec() + + await self.db_pool.simple_update_one( + "access_tokens", + {"id": token_id}, + {"last_validated": now}, + desc="update_access_token_last_validated", + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): @@ -1150,6 +1186,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): The token ID """ next_id = self._access_tokens_id_gen.get_next() + now = self._clock.time_msec() await self.db_pool.simple_insert( "access_tokens", @@ -1160,6 +1197,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "device_id": device_id, "valid_until_ms": valid_until_ms, "puppets_user_id": puppets_user_id, + "last_validated": now, }, desc="add_access_token_to_user", ) diff --git a/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql new file mode 100644 index 0000000000..1a101cd5eb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The last time this access token was "validated" (i.e. logged in or succeeded +-- at user-interactive authentication). +ALTER TABLE access_tokens ADD COLUMN last_validated BIGINT; diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 51323b3da3..ac66a4e0b7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union from twisted.internet.defer import succeed @@ -177,13 +177,8 @@ class UIAuthTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) - self.user_tok = self.login("test", self.user_pass) - - def get_device_ids(self, access_token: str) -> List[str]: - # Get the list of devices so one can be deleted. - channel = self.make_request("GET", "devices", access_token=access_token,) - self.assertEqual(channel.code, 200) - return [d["device_id"] for d in channel.json_body["devices"]] + self.device_id = "dev1" + self.user_tok = self.login("test", self.user_pass, self.device_id) def delete_device( self, @@ -219,11 +214,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ Test user interactive authentication outside of registration. """ - device_id = self.get_device_ids(self.user_tok)[0] - # Attempt to delete this device. 
         # Returns a 401 as per the spec
-        channel = self.delete_device(self.user_tok, device_id, 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)

         # Grab the session
         session = channel.json_body["session"]
@@ -233,7 +226,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Make another request providing the UI auth flow.
         self.delete_device(
             self.user_tok,
-            device_id,
+            self.device_id,
             200,
             {
                 "auth": {
@@ -252,14 +245,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         UIA - check that still works.
         """

-        device_id = self.get_device_ids(self.user_tok)[0]
-        channel = self.delete_device(self.user_tok, device_id, 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
         session = channel.json_body["session"]

         # Make another request providing the UI auth flow.
         self.delete_device(
             self.user_tok,
-            device_id,
+            self.device_id,
             200,
             {
                 "auth": {
@@ -282,14 +274,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
         session ID should be rejected.
         """
         # Create a second login.
-        self.login("test", self.user_pass)
-
-        device_ids = self.get_device_ids(self.user_tok)
-        self.assertEqual(len(device_ids), 2)
+        self.login("test", self.user_pass, "dev2")

         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_devices(401, {"devices": [device_ids[0]]})
+        channel = self.delete_devices(401, {"devices": [self.device_id]})

         # Grab the session
         session = channel.json_body["session"]
@@ -301,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.delete_devices(
             200,
             {
-                "devices": [device_ids[1]],
+                "devices": ["dev2"],
                 "auth": {
                     "type": "m.login.password",
                     "identifier": {"type": "m.id.user", "user": self.user},
@@ -316,14 +305,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
         The initial requested URI cannot be modified during the user interactive authentication session.
         """
         # Create a second login.
-        self.login("test", self.user_pass)
-
-        device_ids = self.get_device_ids(self.user_tok)
-        self.assertEqual(len(device_ids), 2)
+        self.login("test", self.user_pass, "dev2")

         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(self.user_tok, device_ids[0], 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)

         # Grab the session
         session = channel.json_body["session"]
@@ -332,9 +318,11 @@ class UIAuthTests(unittest.HomeserverTestCase):

         # Make another request providing the UI auth flow, but try to delete the
         # second device. This results in an error.
+        #
+        # This makes use of the fact that the device ID is embedded into the URL.
         self.delete_device(
             self.user_tok,
-            device_ids[1],
+            "dev2",
             403,
             {
                 "auth": {
@@ -346,6 +334,52 @@ class UIAuthTests(unittest.HomeserverTestCase):
             },
         )

+    @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}})
+    def test_can_reuse_session(self):
+        """
+        The session can be reused if configured.
+
+        Compare to test_cannot_change_uri.
+        """
+        # Create a second and third login.
+        self.login("test", self.user_pass, "dev2")
+        self.login("test", self.user_pass, "dev3")
+
+        # Attempt to delete a device. This works since the user just logged in.
+        self.delete_device(self.user_tok, "dev2", 200)
+
+        # Move the clock forward past the validation timeout.
+        self.reactor.advance(6)
+
+        # Deleting another device throws the user into UI auth.
+        channel = self.delete_device(self.user_tok, "dev3", 401)
+
+        # Grab the session
+        session = channel.json_body["session"]
+        # Ensure that flows are what is expected.
+        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+        # Make another request providing the UI auth flow.
+        self.delete_device(
+            self.user_tok,
+            "dev3",
+            200,
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "identifier": {"type": "m.id.user", "user": self.user},
+                    "password": self.user_pass,
+                    "session": session,
+                },
+            },
+        )
+
+        # Make another request, but try to delete the first device. This works
+        # due to re-using the previous session.
+        #
+        # Note that *no auth* information is provided, not even a session ID!
+        self.delete_device(self.user_tok, self.device_id, 200)
+
     def test_does_not_offer_password_for_sso_user(self):
         login_resp = self.helper.login_via_oidc("username")
         user_tok = login_resp["access_token"]
@@ -361,8 +395,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
     def test_does_not_offer_sso_for_password_user(self):
         # now call the device deletion API: we should get the option to auth with SSO
         # and not password.
-        device_ids = self.get_device_ids(self.user_tok)
-        channel = self.delete_device(self.user_tok, device_ids[0], 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)

         flows = channel.json_body["flows"]
         self.assertEqual(flows, [{"stages": ["m.login.password"]}])
@@ -373,8 +406,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
         self.assertEqual(login_resp["user_id"], self.user)

-        device_ids = self.get_device_ids(self.user_tok)
-        channel = self.delete_device(self.user_tok, device_ids[0], 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)

         flows = channel.json_body["flows"]
         # we have no particular expectations of ordering here
--
cgit 1.5.1


From 28877fade90a5cfb3457c9e6c70924dbbe8af715 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Fri, 18 Dec 2020 14:19:46 +0000
Subject: Implement a username picker for synapse (#8942)

The final part (for now) of my work to implement a username picker in synapse
itself. The idea is that we allow `UsernameMappingProvider`s to return
`localpart=None`, in which case, rather than redirecting the browser back to
the client, we redirect to a username-picker resource, which allows the user
to enter a username. We *then* complete the SSO flow (including doing the
client permission checks).

The static resources for the username picker itself (in
https://github.com/matrix-org/synapse/tree/rav/username_picker/synapse/res/username_picker)
are essentially lifted wholesale from
https://github.com/matrix-org/matrix-synapse-saml-mozilla/tree/master/matrix_synapse_saml_mozilla/res.
As the comment says, we might want to think about making them customisable,
but that can be a follow-up.

Fixes #8876.
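As a sketch of the provider contract this commit introduces (the method name
matches the mapping-provider API documented below; the body is illustrative,
not part of this patch):

    # Illustrative only: a mapping provider that defers the username choice
    # to the new picker by returning localpart=None.
    async def map_user_attributes(self, userinfo, token, failures):
        return {
            "localpart": None,  # None => Synapse serves the username picker
            "display_name": userinfo.get("name"),  # still pre-filled from the IdP
        }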
---
 changelog.d/8942.feature                     |   1 +
 docs/sample_config.yaml                      |   5 +-
 docs/sso_mapping_providers.md                |  28 +--
 synapse/app/homeserver.py                    |   2 +
 synapse/config/oidc_config.py                |   5 +-
 synapse/handlers/oidc_handler.py             |  59 +++----
 synapse/handlers/sso.py                      | 254 ++++++++++++++++++++++++++-
 synapse/res/username_picker/index.html       |  19 ++
 synapse/res/username_picker/script.js        |  95 ++++++++++
 synapse/res/username_picker/style.css        |  27 +++
 synapse/rest/synapse/client/pick_username.py |  88 ++++++++++
 synapse/types.py                             |   8 +-
 tests/handlers/test_oidc.py                  | 143 ++++++++++++++-
 tests/unittest.py                            |   8 +-
 14 files changed, 683 insertions(+), 59 deletions(-)
 create mode 100644 changelog.d/8942.feature
 create mode 100644 synapse/res/username_picker/index.html
 create mode 100644 synapse/res/username_picker/script.js
 create mode 100644 synapse/res/username_picker/style.css
 create mode 100644 synapse/rest/synapse/client/pick_username.py

(limited to 'tests')

diff --git a/changelog.d/8942.feature b/changelog.d/8942.feature
new file mode 100644
index 0000000000..d450ef4998
--- /dev/null
+++ b/changelog.d/8942.feature
@@ -0,0 +1 @@
+Add support for allowing users to pick their own user ID during a single-sign-on login.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 549c581a97..077cb619c7 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1825,9 +1825,10 @@ oidc_config:
     #  * user: The claims returned by the UserInfo Endpoint and/or in the ID
     #    Token
     #
-    # This must be configured if using the default mapping provider.
+    # If this is not set, the user will be prompted to choose their
+    # own username.
     #
-    localpart_template: "{{ user.preferred_username }}"
+    #localpart_template: "{{ user.preferred_username }}"

     # Jinja2 template for the display name to set on first login.
     #
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index 7714b1d844..e1d6ede7ba 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -15,12 +15,18 @@ where SAML mapping providers come into play.
 SSO mapping providers are currently supported for OpenID and SAML SSO
 configurations. Please see the details below for how to implement your own.

-It is the responsibility of the mapping provider to normalise the SSO attributes
-and map them to a valid Matrix ID. The
-[specification for Matrix IDs](https://matrix.org/docs/spec/appendices#user-identifiers)
-has some information about what is considered valid. Alternately an easy way to
-ensure it is valid is to use a Synapse utility function:
-`synapse.types.map_username_to_mxid_localpart`.
+It is up to the mapping provider whether the user should be assigned a predefined
+Matrix ID based on the SSO attributes, or if the user should be allowed to
+choose their own username.
+
+In the first case - where users are automatically allocated a Matrix ID - it is
+the responsibility of the mapping provider to normalise the SSO attributes and
+map them to a valid Matrix ID. The [specification for Matrix
+IDs](https://matrix.org/docs/spec/appendices#user-identifiers) has some
+information about what is considered valid.
+
+If the mapping provider does not assign a Matrix ID, then Synapse will
+automatically serve an HTML page allowing the user to pick their own username.

 External mapping providers are provided to Synapse in the form of an external
 Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere,
@@ -80,8 +86,9 @@ A custom mapping provider must specify the following methods:
       with failures=1.
      The method should then return a different `localpart` value, such as
      `john.doe1`.
   - Returns a dictionary with two keys:
-    - localpart: A required string, used to generate the Matrix ID.
-    - displayname: An optional string, the display name for the user.
+    - `localpart`: A string, used to generate the Matrix ID. If this is
+      `None`, the user is prompted to pick their own username.
+    - `displayname`: An optional string, the display name for the user.

 * `get_extra_attributes(self, userinfo, token)`
   - This method must be async.
   - Arguments:
@@ -165,12 +172,13 @@ A custom mapping provider must specify the following methods:
     redirected to.
   - This method must return a dictionary, which will then be used by Synapse
     to build a new user. The following keys are allowed:
-    * `mxid_localpart` - Required. The mxid localpart of the new user.
+    * `mxid_localpart` - The mxid localpart of the new user. If this is
+      `None`, the user is prompted to pick their own username.
     * `displayname` - The displayname of the new user. If not provided, will default to
       the value of `mxid_localpart`.
     * `emails` - A list of emails for the new user. If not provided, will
       default to an empty list.
-
+
     Alternatively it can raise a `synapse.api.errors.RedirectException` to redirect the
     user to another page. This is useful to prompt the user for additional
     information, e.g. if you want them to provide their own username.
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index bbb7407838..8d9b53be53 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
 from synapse.storage import DataStore
@@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer):
                 "/_matrix/client/versions": client_resource,
                 "/.well-known/matrix/client": WellKnownResource(self),
                 "/_synapse/admin": AdminRestResource(self),
+                "/_synapse/client/pick_username": pick_username_resource(self),
             }
         )

diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 1abf8ed405..4e3055282d 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -203,9 +203,10 @@ class OIDCConfig(Config):
         #  * user: The claims returned by the UserInfo Endpoint and/or in the ID
         #    Token
         #
-        # This must be configured if using the default mapping provider.
+        # If this is not set, the user will be prompted to choose their
+        # own username.
         #
-        localpart_template: "{{{{ user.preferred_username }}}}"
+        #localpart_template: "{{{{ user.preferred_username }}}}"

         # Jinja2 template for the display name to set on first login.
         #
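For illustration, a homeserver config fragment that opts into the picker simply
omits `localpart_template`; the issuer and client values below are placeholders,
not from this patch:

    oidc_config:
      enabled: true
      issuer: "https://idp.example.com/"  # placeholder
      client_id: "synapse"                # placeholder
      client_secret: "secret"             # placeholder
      user_mapping_provider:
        config:
          # no localpart_template: new users are sent to the username picker
          display_name_template: "{{ user.name }}"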
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index cbd11a1382..709f8dfc13 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -947,7 +947,7 @@ class OidcHandler(BaseHandler):


 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+    "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
 )
 C = TypeVar("C")

@@ -1028,10 +1028,10 @@ env = Environment(finalize=jinja_finalize)

 @attr.s
 class JinjaOidcMappingConfig:
-    subject_claim = attr.ib()  # type: str
-    localpart_template = attr.ib()  # type: Template
-    display_name_template = attr.ib()  # type: Optional[Template]
-    extra_attributes = attr.ib()  # type: Dict[str, Template]
+    subject_claim = attr.ib(type=str)
+    localpart_template = attr.ib(type=Optional[Template])
+    display_name_template = attr.ib(type=Optional[Template])
+    extra_attributes = attr.ib(type=Dict[str, Template])


 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1047,18 +1047,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")

-        if "localpart_template" not in config:
-            raise ConfigError(
-                "missing key: oidc_config.user_mapping_provider.config.localpart_template"
-            )
-
-        try:
-            localpart_template = env.from_string(config["localpart_template"])
-        except Exception as e:
-            raise ConfigError(
-                "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
-                % (e,)
-            )
+        localpart_template = None  # type: Optional[Template]
+        if "localpart_template" in config:
+            try:
+                localpart_template = env.from_string(config["localpart_template"])
+            except Exception as e:
+                raise ConfigError(
+                    "invalid jinja template", path=["localpart_template"]
+                ) from e

         display_name_template = None  # type: Optional[Template]
         if "display_name_template" in config:
@@ -1066,26 +1062,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                 display_name_template = env.from_string(config["display_name_template"])
             except Exception as e:
                 raise ConfigError(
-                    "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
-                    % (e,)
-                )
+                    "invalid jinja template", path=["display_name_template"]
+                ) from e

         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
             extra_attributes_config = config.get("extra_attributes") or {}
             if not isinstance(extra_attributes_config, dict):
-                raise ConfigError(
-                    "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
-                )
+                raise ConfigError("must be a dict", path=["extra_attributes"])

             for key, value in extra_attributes_config.items():
                 try:
                     extra_attributes[key] = env.from_string(value)
                 except Exception as e:
                     raise ConfigError(
-                        "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
-                        % (key, e)
-                    )
+                        "invalid jinja template", path=["extra_attributes", key]
+                    ) from e

         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
@@ -1100,14 +1092,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     async def map_user_attributes(
         self, userinfo: UserInfo, token: Token, failures: int
     ) -> UserAttributeDict:
-        localpart = self._config.localpart_template.render(user=userinfo).strip()
+        localpart = None
+
+        if self._config.localpart_template:
+            localpart = self._config.localpart_template.render(user=userinfo).strip()

-        # Ensure only valid characters are included in the MXID.
-        localpart = map_username_to_mxid_localpart(localpart)
+            # Ensure only valid characters are included in the MXID.
+            localpart = map_username_to_mxid_localpart(localpart)

-        # Append suffix integer if last call to this function failed to produce
-        # a usable mxid.
-        localpart += str(failures) if failures else ""
+            # Append suffix integer if last call to this function failed to produce
+            # a usable mxid.
+            localpart += str(failures) if failures else ""

         display_name = None  # type: Optional[str]
         if self._config.display_name_template is not None:
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index f054b66a53..548b02211b 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -13,17 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional

 import attr
+from typing_extensions import NoReturn

 from twisted.web.http import Request

-from synapse.api.errors import RedirectException
+from synapse.api.errors import RedirectException, SynapseError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -40,16 +42,52 @@ class MappingException(Exception):

 @attr.s
 class UserAttributes:
-    localpart = attr.ib(type=str)
+    # the localpart of the mxid that the mapper has assigned to the user.
+    # if `None`, the mapper has not picked a userid, and the user should be prompted to
+    # enter one.
+    localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
     emails = attr.ib(type=List[str], default=attr.Factory(list))


+@attr.s(slots=True)
+class UsernameMappingSession:
+    """Data we track about SSO sessions"""
+
+    # A unique identifier for this SSO provider, e.g. "oidc" or "saml".
+    auth_provider_id = attr.ib(type=str)
+
+    # user ID on the IdP server
+    remote_user_id = attr.ib(type=str)
+
+    # attributes returned by the ID mapper
+    display_name = attr.ib(type=Optional[str])
+    emails = attr.ib(type=List[str])
+
+    # An optional dictionary of extra attributes to be provided to the client in the
+    # login response.
+    extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+    # where to redirect the client back to
+    client_redirect_url = attr.ib(type=str)
+
+    # expiry time for the session, in milliseconds
+    expiry_time_ms = attr.ib(type=int)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
 class SsoHandler:
     # The number of attempts to ask the mapping provider for when generating an MXID.
     _MAP_USERNAME_RETRIES = 1000

+    # the time a UsernameMappingSession remains valid for
+    _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
     def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
         self._store = hs.get_datastore()
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
@@ -59,6 +97,9 @@ class SsoHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())

+        # a map from session id to session data
+        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
+
     def render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
@@ -206,6 +247,18 @@ class SsoHandler:
         # Otherwise, generate a new user.
         if not user_id:
             attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+            if attributes.localpart is None:
+                # the mapper doesn't return a username. bail out with a redirect to
+                # the username picker.
+                await self._redirect_to_username_picker(
+                    auth_provider_id,
+                    remote_user_id,
+                    attributes,
+                    client_redirect_url,
+                    extra_login_attributes,
+                )
+
             user_id = await self._register_mapped_user(
                 attributes,
                 auth_provider_id,
@@ -243,10 +296,8 @@ class SsoHandler:
         )

         if not attributes.localpart:
-            raise MappingException(
-                "Error parsing SSO response: SSO mapping provider plugin "
-                "did not return a localpart value"
-            )
+            # the mapper has not picked a localpart
+            return attributes

         # Check if this mxid already exists
         user_id = UserID(attributes.localpart, self._server_name).to_string()
@@ -261,6 +312,59 @@ class SsoHandler:
         )
         return attributes

+    async def _redirect_to_username_picker(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        attributes: UserAttributes,
+        client_redirect_url: str,
+        extra_login_attributes: Optional[JsonDict],
+    ) -> NoReturn:
+        """Creates a UsernameMappingSession and redirects the browser
+
+        Called if the user mapping provider doesn't return a localpart for a new user.
+        Raises a RedirectException which redirects the browser to the username picker.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            attributes: the user attributes returned by the user mapping provider.
+
+            client_redirect_url: The redirect URL passed in by the client, which we
+                will eventually redirect back to.
+
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
+
+        Raises:
+            RedirectException
+        """
+        session_id = random_string(16)
+        now = self._clock.time_msec()
+        session = UsernameMappingSession(
+            auth_provider_id=auth_provider_id,
+            remote_user_id=remote_user_id,
+            display_name=attributes.display_name,
+            emails=attributes.emails,
+            client_redirect_url=client_redirect_url,
+            expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+            extra_login_attributes=extra_login_attributes,
+        )
+
+        self._username_mapping_sessions[session_id] = session
+        logger.info("Recorded registration session id %s", session_id)
+
+        # Set the cookie and redirect to the username picker
+        e = RedirectException(b"/_synapse/client/pick_username")
+        e.cookies.append(
+            b"%s=%s; path=/"
+            % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+        )
+        raise e
+
     async def _register_mapped_user(
         self,
         attributes: UserAttributes,
@@ -269,9 +373,38 @@ class SsoHandler:
         user_agent: str,
         ip_address: str,
     ) -> str:
+        """Register a new SSO user.
+
+        This is called once we have successfully mapped the remote user id onto a local
+        user id, one way or another.
+
+        Args:
+            attributes: user attributes returned by the user mapping provider,
+                including a non-empty localpart.
+
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            user_agent: The user-agent in the HTTP request (used for potential
+                shadow-banning.)
+
+            ip_address: The IP address of the requester (used for potential
+                shadow-banning.)
+
+        Raises:
+            a MappingException if the localpart is invalid.
+
+            a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+            is already taken.
+        """
+
         # Since the localpart is provided via a potentially untrusted module,
         # ensure the MXID is valid before registering.
-        if contains_invalid_mxid_characters(attributes.localpart):
+        if not attributes.localpart or contains_invalid_mxid_characters(
+            attributes.localpart
+        ):
             raise MappingException("localpart is invalid: %s" % (attributes.localpart,))

         logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -326,3 +459,108 @@ class SsoHandler:
         await self._auth_handler.complete_sso_ui_auth(
             user_id, ui_auth_session_id, request
         )
+
+    async def check_username_availability(
+        self, localpart: str, session_id: str,
+    ) -> bool:
+        """Handle an "is username available" callback check
+
+        Args:
+            localpart: desired localpart
+            session_id: the session id for the username picker
+        Returns:
+            True if the username is available
+        Raises:
+            SynapseError if the localpart is invalid or the session is unknown
+        """
+
+        # make sure that there is a valid mapping session, to stop people dictionary-
+        # scanning for accounts
+
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info(
+            "[session %s] Checking for availability of username %s",
+            session_id,
+            localpart,
+        )
+
+        if contains_invalid_mxid_characters(localpart):
+            raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+        user_id = UserID(localpart, self._server_name).to_string()
+        user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+        logger.info("[session %s] users: %s", session_id, user_infos)
+        return not user_infos
+
+    async def handle_submit_username_request(
+        self, request: SynapseRequest, localpart: str, session_id: str
+    ) -> None:
+        """Handle a request to the username-picker 'submit' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            localpart: localpart requested by the user
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if not session:
+            logger.info("Couldn't find session id %s", session_id)
+            raise SynapseError(400, "unknown session")
+
+        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+
+        attributes = UserAttributes(
+            localpart=localpart,
+            display_name=session.display_name,
+            emails=session.emails,
+        )
+
+        # the following will raise a 400 error if the username has been taken in the
+        # meantime.
+        user_id = await self._register_mapped_user(
+            attributes,
+            session.auth_provider_id,
+            session.remote_user_id,
+            request.get_user_agent(""),
+            request.getClientIP(),
+        )
+
+        logger.info("[session %s] Registered userid %s", session_id, user_id)
+
+        # delete the mapping session and the cookie
+        del self._username_mapping_sessions[session_id]
+
+        # delete the cookie
+        request.addCookie(
+            USERNAME_MAPPING_SESSION_COOKIE_NAME,
+            b"",
+            expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+            path=b"/",
+        )
+
+        await self._auth_handler.complete_sso_login(
+            user_id,
+            request,
+            session.client_redirect_url,
+            session.extra_login_attributes,
+        )
+
+    def _expire_old_sessions(self):
+        to_expire = []
+        now = int(self._clock.time_msec())
+
+        for session_id, session in self._username_mapping_sessions.items():
+            if session.expiry_time_ms <= now:
+                to_expire.append(session_id)
+
+        for session_id in to_expire:
+            logger.info("Expiring mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]
diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html
new file mode 100644
index 0000000000..37ea8bb6d8
--- /dev/null
+++ b/synapse/res/username_picker/index.html
@@ -0,0 +1,19 @@
+<html>
+<head>
+    <title>Synapse Login</title>
+    <link rel="stylesheet" href="style.css" type="text/css" />
+</head>
+<body>
+    <div class="card">
+        <form method="post" class="form__input" id="form" action="submit">
+            <label for="field-username">Please pick your username:</label>
+            <input type="text" name="username" id="field-username" autofocus>
+            <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+        </form>
+        <!-- this is used for feedback -->
+        <div role="alert" class="tooltip hidden" id="message"></div>
+        <script type="text/javascript" src="script.js"></script>
+    </div>
+</body>
+</html>
diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js
new file mode 100644
index 0000000000..416a7c6f41
--- /dev/null
+++ b/synapse/res/username_picker/script.js
@@ -0,0 +1,95 @@
+let inputField = document.getElementById("field-username");
+let inputForm = document.getElementById("form");
+let submitButton = document.getElementById("button-submit");
+let message = document.getElementById("message");
+
+// Submit username and receive response
+function showMessage(messageText) {
+    // Unhide the message text
+    message.classList.remove("hidden");
+
+    message.textContent = messageText;
+};
+
+function doSubmit() {
+    showMessage("Success. Please wait a moment for your browser to redirect.");
+
+    // remove the event handler before re-submitting the form.
+    delete inputForm.onsubmit;
+    inputForm.submit();
+}
+
+function onResponse(response) {
+    // Display message
+    showMessage(response);
+
+    // Enable submit button and input field
+    submitButton.classList.remove('button--disabled');
+    submitButton.value = "Submit";
+};
+
+let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]");
+function usernameIsValid(username) {
+    return !allowedUsernameCharacters.test(username);
+}
+let allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
+
+function buildQueryString(params) {
+    return Object.keys(params)
+        .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k]))
+        .join('&');
+}
+
+function submitUsername(username) {
+    if(username.length == 0) {
+        onResponse("Please enter a username.");
+        return;
+    }
+    if(!usernameIsValid(username)) {
+        onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString);
+        return;
+    }
+
+    // if this browser doesn't support fetch, skip the availability check.
+    if(!window.fetch) {
+        doSubmit();
+        return;
+    }
+
+    let check_uri = 'check?' + buildQueryString({"username": username});
+    fetch(check_uri, {
+        // include the cookie
+        "credentials": "same-origin",
+    }).then((response) => {
+        if(!response.ok) {
+            // for non-200 responses, raise the body of the response as an exception
+            return response.text().then((text) => { throw text; });
+        } else {
+            return response.json();
+        }
+    }).then((json) => {
+        if(json.error) {
+            throw json.error;
+        } else if(json.available) {
+            doSubmit();
+        } else {
+            onResponse("This username is not available, please choose another.");
+        }
+    }).catch((err) => {
+        onResponse("Error checking username availability: " + err);
+    });
+}
+
+function clickSubmit() {
+    event.preventDefault();
+    if(submitButton.classList.contains('button--disabled')) { return; }
+
+    // Disable submit button and input field
+    submitButton.classList.add('button--disabled');
+
+    // Submit username
+    submitButton.value = "Checking...";
+    submitUsername(inputField.value);
+};
+
+inputForm.onsubmit = clickSubmit;
diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css
new file mode 100644
index 0000000000..745bd4c684
--- /dev/null
+++ b/synapse/res/username_picker/style.css
@@ -0,0 +1,27 @@
+input[type="text"] {
+    font-size: 100%;
+    background-color: #ededf0;
+    border: 1px solid #fff;
+    border-radius: .2em;
+    padding: .5em .9em;
+    display: block;
+    width: 26em;
+}
+
+.button--disabled {
+    border-color: #fff;
+    background-color: transparent;
+    color: #000;
+    text-transform: none;
+}
+
+.hidden {
+    display: none;
+}
+
+.tooltip {
+    background-color: #f9f9fa;
+    padding: 1em;
+    margin: 1em 0;
+}
+
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
new file mode 100644
index 0000000000..d3b6803e65
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import pkg_resources
+
+from twisted.web.http import Request
+from twisted.web.resource import Resource
+from twisted.web.static import File
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
+from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+def pick_username_resource(hs: "HomeServer") -> Resource:
+    """Factory method to generate the username picker resource.
+
+    This resource gets mounted under /_synapse/client/pick_username. The top-level
+    resource is just a File resource which serves up the static files in the resources
+    "res" directory, but it has a couple of children:
+
+    * "submit", which does the mechanics of registering the new user, and redirects the
+      browser back to the client URL
+
+    * "check": checks if a userid is free.
+ """ + + # XXX should we make this path customisable so that admins can restyle it? + base_path = pkg_resources.resource_filename("synapse", "res/username_picker") + + res = File(base_path) + res.putChild(b"submit", SubmitResource(hs)) + res.putChild(b"check", AvailabilityCheckResource(hs)) + + return res + + +class AvailabilityCheckResource(DirectServeJsonResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request): + localpart = parse_string(request, "username", required=True) + + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + + is_available = await self._sso_handler.check_username_availability( + localpart, session_id.decode("ascii", errors="replace") + ) + return 200, {"available": is_available} + + +class SubmitResource(DirectServeHtmlResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_POST(self, request: SynapseRequest): + localpart = parse_string(request, "username", required=True) + + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + + await self._sso_handler.handle_submit_username_request( + request, localpart, session_id.decode("ascii", errors="replace") + ) diff --git a/synapse/types.py b/synapse/types.py index 3ab6bdbe06..c7d4e95809 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -349,15 +349,17 @@ NON_MXID_CHARACTER_PATTERN = re.compile( ) -def map_username_to_mxid_localpart(username, case_sensitive=False): +def map_username_to_mxid_localpart( + username: Union[str, bytes], case_sensitive: bool = False +) -> str: """Map a username onto a string suitable for a MXID This follows the algorithm laid out at https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. Args: - username (unicode|bytes): username to be mapped - case_sensitive (bool): true if TEST and test should be mapped + username: username to be mapped + case_sensitive: true if TEST and test should be mapped onto different mxids Returns: diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index c54f1c5797..368d600b33 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,14 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import json -from urllib.parse import parse_qs, urlparse +import re +from typing import Dict +from urllib.parse import parse_qs, urlencode, urlparse from mock import ANY, Mock, patch import pymacaroons +from twisted.web.resource import Resource + +from synapse.api.errors import RedirectException from synapse.handlers.oidc_handler import OidcError from synapse.handlers.sso import MappingException +from synapse.rest.client.v1 import login +from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.server import HomeServer from synapse.types import UserID @@ -793,6 +800,140 @@ class OidcHandlerTestCase(HomeserverTestCase): "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) + def test_empty_localpart(self): + """Attempts to map onto an empty localpart should be rejected.""" + userinfo = { + "sub": "tester", + "username": "", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.username }}"} + } + } + } + ) + def test_null_localpart(self): + """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" + userinfo = { + "sub": "tester", + "username": None, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + +class UsernamePickerTestCase(HomeserverTestCase): + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + oidc_config = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": { + "config": {"display_name_template": "{{ user.displayname }}"} + }, + } + + # Update this config with what's in the default config so that + # override_config works as expected. + oidc_config.update(config.get("oidc_config", {})) + config["oidc_config"] = oidc_config + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + client_redirect_url = "https://whitelisted.client" + + # first of all, mock up an OIDC callback to the OidcHandler, which should + # raise a RedirectException + userinfo = {"sub": "tester", "displayname": "Jonny"} + f = self.get_failure( + _make_callback_with_userinfo( + self.hs, userinfo, client_redirect_url=client_redirect_url + ), + RedirectException, + ) + + # check the Location and cookies returned by the RedirectException + self.assertEqual(f.value.location, b"/_synapse/client/pick_username") + cookieheader = f.value.cookies[0] + regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);") + m = regex.search(cookieheader) + if not m: + self.fail("cookie header %s does not match %s" % (cookieheader, regex)) + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. 
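+        # (`_username_mapping_sessions` is a plain in-memory dict on the sso
+        # handler, keyed by session id, so the test can inspect it directly.)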
+ session_id = m.group(1).decode("ascii") + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, username_mapping_sessions, "session id not found in map" + ) + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, client_redirect_url) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # back to the client + submit_path = f.value.location + b"/submit" + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=submit_path, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", cookieheader), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location starts with the requested redirect URL + self.assertEqual( + location_headers[0][: len(client_redirect_url)], client_redirect_url + ) + + # fish the login token out of the returned redirect uri + parts = urlparse(location_headers[0]) + query = parse_qs(parts.query) + login_token = query["loginToken"][0] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") + async def _make_callback_with_userinfo( hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" diff --git a/tests/unittest.py b/tests/unittest.py index 39e5e7b85c..af7f752c5a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Dict, Optional, Type, TypeVar, Union +from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from mock import Mock, patch @@ -383,6 +383,9 @@ class HomeserverTestCase(TestCase): federation_auth_origin: str = None, content_is_form: bool = False, await_result: bool = True, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the @@ -405,6 +408,8 @@ class HomeserverTestCase(TestCase): true (the default), will pump the test reactor until the the renderer tells the channel the request is finished. + custom_headers: (name, value) pairs to add as request headers + Returns: The FakeChannel object which stores the result of the request. 
""" @@ -420,6 +425,7 @@ class HomeserverTestCase(TestCase): federation_auth_origin, content_is_form, await_result, + custom_headers, ) def setup_test_homeserver(self, *args, **kwargs): -- cgit 1.5.1 From 5e7d75daa2d53ea74c75f9a0db52c5590fc22038 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Dec 2020 15:00:34 +0000 Subject: Fix mainline ordering in state res v2 (#8971) This had two effects 1) it'd give the wrong answer and b) would iterate *all* power levels in the auth chain of each event. The latter of which can be *very* expensive for certain types of IRC bridge rooms that have large numbers of power level changes. --- changelog.d/8971.bugfix | 1 + synapse/state/v2.py | 2 +- tests/state/test_v2.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8971.bugfix (limited to 'tests') diff --git a/changelog.d/8971.bugfix b/changelog.d/8971.bugfix new file mode 100644 index 0000000000..c3e44b8c0b --- /dev/null +++ b/changelog.d/8971.bugfix @@ -0,0 +1 @@ +Fix small bug in v2 state resolution algorithm, which could also cause performance issues for rooms with large numbers of power levels. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index f85124bf81..e585954bd8 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -658,7 +658,7 @@ async def _get_mainline_depth_for_event( # We do an iterative search, replacing `event with the power level in its # auth events (if any) while tmp_event: - depth = mainline_map.get(event.event_id) + depth = mainline_map.get(tmp_event.event_id) if depth is not None: return depth diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 09f4f32a02..77c72834f2 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -88,7 +88,7 @@ class FakeEvent: event_dict = { "auth_events": [(a, {}) for a in auth_events], "prev_events": [(p, {}) for p in prev_events], - "event_id": self.node_id, + "event_id": self.event_id, "sender": self.sender, "type": self.type, "content": self.content, @@ -381,6 +381,61 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) + def test_mainline_sort(self): + """Tests that the mainline ordering works correctly. + """ + + events = [ + FakeEvent( + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA1", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA2", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={ + "users": {ALICE: 100, BOB: 50}, + "events": {EventTypes.PowerLevels: 100}, + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + ] + + edges = [ + ["END", "T3", "PA2", "T2", "PA1", "T1", "START"], + ["END", "T4", "PB", "PA1"], + ] + + # We expect T3 to be picked as the other topics are pointing at older + # power levels. Note that without mainline ordering we'd pick T4 due to + # it being sent *after* T3. 
+ expected_state_ids = ["T3", "PA2"] + + self.do_check(events, edges, expected_state_ids) + def do_check(self, events, edges, expected_state_ids): """Take a list of events and edges and calculate the state of the graph at END, and asserts it matches `expected_state_ids` -- cgit 1.5.1 From d781a81e692563c5785e3efd4aa2487696b9c995 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Dec 2020 15:37:19 +0000 Subject: Allow server admin to get admin bit in rooms where local user is an admin (#8756) This adds an admin API that allows a server admin to get power in a room if a local user has power in a room. Will also invite the user if they're not in the room and its a private room. Can specify another user (rather than the admin user) to be granted power. Co-authored-by: Matthew Hodgson --- changelog.d/8756.feature | 1 + docs/admin_api/rooms.md | 20 +++++- synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/rooms.py | 136 +++++++++++++++++++++++++++++++++++++++- tests/rest/admin/test_room.py | 138 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 294 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8756.feature (limited to 'tests') diff --git a/changelog.d/8756.feature b/changelog.d/8756.feature new file mode 100644 index 0000000000..03eb79fb0a --- /dev/null +++ b/changelog.d/8756.feature @@ -0,0 +1 @@ +Add admin API that lets server admins get power in rooms in which local users have power. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index d7b1740fe3..9e560003a9 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -8,6 +8,7 @@ * [Parameters](#parameters-1) * [Response](#response) * [Undoing room shutdowns](#undoing-room-shutdowns) +- [Make Room Admin API](#make-room-admin-api) # List Room API @@ -467,6 +468,7 @@ The following fields are returned in the JSON response body: the old room to the new. * `new_room_id` - A string representing the room ID of the new room. + ## Undoing room shutdowns *Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level, @@ -492,4 +494,20 @@ You will have to manually handle, if you so choose, the following: * Aliases that would have been redirected to the Content Violation room. * Users that would have been booted from the room (and will have been force-joined to the Content Violation room). -* Removal of the Content Violation room if desired. \ No newline at end of file +* Removal of the Content Violation room if desired. + + +# Make Room Admin API + +Grants another user the highest power available to a local user who is in the room. +If the user is not in the room, and it is not publicly joinable, then invite the user. 
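+
+A successful request returns `200` with an empty JSON body.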
+ +By default the server admin (the caller) is granted power, but another user can +optionally be specified, e.g.: + +``` + POST /_synapse/admin/v1/rooms//make_room_admin + { + "user_id": "@foo:example.com" + } +``` diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 55ddebb4fe..6f7dc06503 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -38,6 +38,7 @@ from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, + MakeRoomAdminRestServlet, RoomMembersRestServlet, RoomRestServlet, ShutdownRoomRestServlet, @@ -228,6 +229,7 @@ def register_servlets(hs, http_server): EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) + MakeRoomAdminRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index b902af8028..ab7cc9102a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -16,8 +16,8 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple -from synapse.api.constants import EventTypes, JoinRules -from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -37,6 +37,7 @@ from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester if TYPE_CHECKING: from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -367,3 +368,134 @@ class JoinRoomAliasServlet(RestServlet): ) return 200, {"room_id": room_id} + + +class MakeRoomAdminRestServlet(RestServlet): + """Allows a server admin to get power in a room if a local user has power in + a room. Will also invite the user if they're not in the room and it's a + private room. Can specify another user (rather than the admin user) to be + granted power, e.g.: + + POST/_synapse/admin/v1/rooms//make_room_admin + { + "user_id": "@foo:example.com" + } + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.state_handler = hs.get_state_handler() + self.is_mine_id = hs.is_mine_id + + async def on_POST(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + content = parse_json_object_from_request(request, allow_empty_body=True) + + # Resolve to a room ID, if necessary. + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + # Which user to grant room admin rights to. + user_to_add = content.get("user_id", requester.user.to_string()) + + # Figure out which local users currently have power in the room, if any. 
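+        # (We use the room's power_levels event if it has one, and otherwise
+        # fall back to the room creator, who implicitly has power in a room
+        # with no power_levels event.)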
+ room_state = await self.state_handler.get_current_state(room_id) + if not room_state: + raise SynapseError(400, "Server not in room") + + create_event = room_state[(EventTypes.Create, "")] + power_levels = room_state.get((EventTypes.PowerLevels, "")) + + if power_levels is not None: + # We pick the local user with the highest power. + user_power = power_levels.content.get("users", {}) + admin_users = [ + user_id for user_id in user_power if self.is_mine_id(user_id) + ] + admin_users.sort(key=lambda user: user_power[user]) + + if not admin_users: + raise SynapseError(400, "No local admin user in room") + + admin_user_id = admin_users[-1] + + pl_content = power_levels.content + else: + # If there is no power level events then the creator has rights. + pl_content = {} + admin_user_id = create_event.sender + if not self.is_mine_id(admin_user_id): + raise SynapseError( + 400, "No local admin user in room", + ) + + # Grant the user power equal to the room admin by attempting to send an + # updated power level event. + new_pl_content = dict(pl_content) + new_pl_content["users"] = dict(pl_content.get("users", {})) + new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id] + + fake_requester = create_requester( + admin_user_id, authenticated_entity=requester.authenticated_entity, + ) + + try: + await self.event_creation_handler.create_and_send_nonmember_event( + fake_requester, + event_dict={ + "content": new_pl_content, + "sender": admin_user_id, + "type": EventTypes.PowerLevels, + "state_key": "", + "room_id": room_id, + }, + ) + except AuthError: + # The admin user we found turned out not to have enough power. + raise SynapseError( + 400, "No local admin user in room with power to update power levels." + ) + + # Now we check if the user we're granting admin rights to is already in + # the room. If not and it's not a public room we invite them. 
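+        # (A membership of "join" or "invite" counts as already being in the
+        # room, since an invited user can accept the invite themselves.)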
+        member_event = room_state.get((EventTypes.Member, user_to_add))
+        is_joined = False
+        if member_event:
+            is_joined = member_event.content["membership"] in (
+                Membership.JOIN,
+                Membership.INVITE,
+            )
+
+        if is_joined:
+            return 200, {}
+
+        join_rules = room_state.get((EventTypes.JoinRules, ""))
+        is_public = False
+        if join_rules:
+            is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
+
+        if is_public:
+            return 200, {}
+
+        await self.room_member_handler.update_membership(
+            fake_requester,
+            target=UserID.from_string(user_to_add),
+            room_id=room_id,
+            action=Membership.INVITE,
+        )
+
+        return 200, {}
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 014c30287a..60a5fcecf7 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -20,6 +20,7 @@ from typing import List, Optional
 from mock import Mock
 
 import synapse.rest.admin
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import Codes
 from synapse.rest.client.v1 import directory, events, login, room
 
@@ -1432,6 +1433,143 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
 
+class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.creator = self.register_user("creator", "test")
+        self.creator_tok = self.login("creator", "test")
+
+        self.second_user_id = self.register_user("second", "test")
+        self.second_tok = self.login("second", "test")
+
+        self.public_room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+        self.url = "/_synapse/admin/v1/rooms/{}/make_room_admin".format(
+            self.public_room_id
+        )
+
+    def test_public_room(self):
+        """Test that getting admin in a public room works.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Now we test that we can join the room and ban a user.
+        self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+        self.helper.change_membership(
+            room_id,
+            self.admin_user,
+            "@test:test",
+            Membership.BAN,
+            tok=self.admin_user_tok,
+        )
+
+    def test_private_room(self):
+        """Test that getting admin in a private room works and we get invited.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=False,
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Now we test that we can join the room (we should have received an
+        # invite) and can ban a user.
+        self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+        self.helper.change_membership(
+            room_id,
+            self.admin_user,
+            "@test:test",
+            Membership.BAN,
+            tok=self.admin_user_tok,
+        )
+
+    def test_other_user(self):
+        """Test that granting admin in a public room to a non-admin user works.
+ """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={"user_id": self.second_user_id}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room and ban a user. + self.helper.join(room_id, self.second_user_id, tok=self.second_tok) + self.helper.change_membership( + room_id, + self.second_user_id, + "@test:test", + Membership.BAN, + tok=self.second_tok, + ) + + def test_not_enough_power(self): + """Test that we get a sensible error if there are no local room admins. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + # The creator drops admin rights in the room. + pl = self.helper.get_state( + room_id, EventTypes.PowerLevels, tok=self.creator_tok + ) + pl["users"][self.creator] = 0 + self.helper.send_state( + room_id, EventTypes.PowerLevels, body=pl, tok=self.creator_tok + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + # We expect this to fail with a 400 as there are no room admins. + # + # (Note we assert the error message to ensure that it's not denied for + # some other reason) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + channel.json_body["error"], + "No local admin user in room with power to update power levels.", + ) + + PURGE_TABLES = [ "current_state_events", "event_backward_extremities", -- cgit 1.5.1 From a8026064755362fe3e5dc00f537606d340ce242a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Dec 2020 13:00:14 -0500 Subject: Support PyJWT v2.0.0. (#8986) Tests were broken due to an API changing. The code used in Synapse proper should be compatible with both versions already. --- changelog.d/8986.misc | 1 + tests/rest/client/v1/test_login.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) create mode 100644 changelog.d/8986.misc (limited to 'tests') diff --git a/changelog.d/8986.misc b/changelog.d/8986.misc new file mode 100644 index 0000000000..6aefc78784 --- /dev/null +++ b/changelog.d/8986.misc @@ -0,0 +1 @@ +Support using PyJWT v2.0.0 in the test suite. diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 566776e97e..18932d7518 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -475,8 +475,12 @@ class JWTTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = self.jwt_algorithm return self.hs - def jwt_encode(self, token, secret=jwt_secret): - return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. 
+ result = jwt.encode(token, secret, self.jwt_algorithm) + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( @@ -680,8 +684,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = "RS256" return self.hs - def jwt_encode(self, token, secret=jwt_privatekey): - return jwt.encode(token, secret, "RS256").decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result = jwt.encode(token, secret, "RS256") + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( -- cgit 1.5.1 From d0c3c24eb2bf12d2975093f074daa84569b12ddd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 29 Dec 2020 07:26:29 -0500 Subject: Drop the unused local_invites table. (#8979) This table has been unused since Synapse v1.17.0. --- changelog.d/8979.misc | 1 + .../databases/main/schema/delta/58/27local_invites.sql | 18 ++++++++++++++++++ tests/rest/admin/test_room.py | 1 - 3 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8979.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/27local_invites.sql (limited to 'tests') diff --git a/changelog.d/8979.misc b/changelog.d/8979.misc new file mode 100644 index 0000000000..670821cf90 --- /dev/null +++ b/changelog.d/8979.misc @@ -0,0 +1 @@ +Drop the unused `local_invites` table. diff --git a/synapse/storage/databases/main/schema/delta/58/27local_invites.sql b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql new file mode 100644 index 0000000000..44b2a0572f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql @@ -0,0 +1,18 @@ +/* + * Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is unused since Synapse v1.17.0. +DROP TABLE local_invites; diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 60a5fcecf7..fa620f97f3 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1598,7 +1598,6 @@ PURGE_TABLES = [ "event_push_summary", "pusher_throttle", "group_summary_rooms", - "local_invites", "room_account_data", "room_tags", # "state_groups", # Current impl leaves orphaned state groups around. -- cgit 1.5.1 From 168ba00d0138c3a0f44d07370f97dd854177794f Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 30 Dec 2020 19:27:32 +0000 Subject: Fix RoomDirectoryFederationTests and make them actually run (#8998) The `RoomDirectoryFederationTests` tests were not being run unless explicitly called as an `__init__.py` file was not present in `tests/federation/transport/`. Thus the folder was not a python module, and `trial` did not look inside for any test cases to run. This was found while working on #6739. 
This PR adds an `__init__.py` and also fixes the test in a couple of ways:
- Switch to subclassing `unittest.FederatingHomeserverTestCase` instead, which sets up federation endpoints for us.
- Supply a `federation_auth_origin` to `make_request` in order to act more like the request is coming from another server, instead of just an unauthenticated client requesting a federation endpoint.

I found that the second point makes no difference to the test passing, but felt like the right thing to do if we're testing over federation.
---
 changelog.d/8998.misc                     |  1 +
 tests/federation/transport/__init__.py    |  0
 tests/federation/transport/test_server.py | 39 ++++++++++++++-----------------
 3 files changed, 19 insertions(+), 21 deletions(-)
 create mode 100644 changelog.d/8998.misc
 create mode 100644 tests/federation/transport/__init__.py

(limited to 'tests')

diff --git a/changelog.d/8998.misc b/changelog.d/8998.misc
new file mode 100644
index 0000000000..81346694bd
--- /dev/null
+++ b/changelog.d/8998.misc
@@ -0,0 +1 @@
+Fix `tests.federation.transport.RoomDirectoryFederationTests` and ensure it runs in CI.
\ No newline at end of file
diff --git a/tests/federation/transport/__init__.py b/tests/federation/transport/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 212fb79a00..85500e169c 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,34 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
 from tests import unittest
 from tests.unittest import override_config
 
 
-class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
-        class Authenticator:
-            def authenticate_request(self, request, content):
-                return defer.succeed("otherserver.nottld")
-
-        ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig())
-        server.register_servlets(
-            homeserver, self.resource, Authenticator(), ratelimiter
-        )
-
+class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
     @override_config({"allow_public_rooms_over_federation": False})
     def test_blocked_public_room_list_over_federation(self):
-        channel = self.make_request("GET", "/_matrix/federation/v1/publicRooms")
+        """Test that unauthenticated requests to the public rooms directory 403 when
+        allow_public_rooms_over_federation is False.
+        """
+        channel = self.make_request(
+            "GET",
+            "/_matrix/federation/v1/publicRooms",
+            federation_auth_origin=b"example.com",
+        )
         self.assertEquals(403, channel.code)
 
     @override_config({"allow_public_rooms_over_federation": True})
     def test_open_public_room_list_over_federation(self):
-        channel = self.make_request("GET", "/_matrix/federation/v1/publicRooms")
+        """Test that unauthenticated requests to the public rooms directory 200 when
+        allow_public_rooms_over_federation is True.
+ """ + channel = self.make_request( + "GET", + "/_matrix/federation/v1/publicRooms", + federation_auth_origin=b"example.com", + ) self.assertEquals(200, channel.code) -- cgit 1.5.1 From 0eccf531466d762ede0dd365284a8465bfb18d0f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Sun, 3 Jan 2021 11:25:44 -0500 Subject: Use the SSO handler helpers for CAS registration/login. (#8856) --- changelog.d/8856.misc | 1 + synapse/handlers/cas_handler.py | 112 +++++++++++++++++++++++++------------ synapse/handlers/sso.py | 4 +- tests/handlers/test_cas.py | 121 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 199 insertions(+), 39 deletions(-) create mode 100644 changelog.d/8856.misc create mode 100644 tests/handlers/test_cas.py (limited to 'tests') diff --git a/changelog.d/8856.misc b/changelog.d/8856.misc new file mode 100644 index 0000000000..1507073e4f --- /dev/null +++ b/changelog.d/8856.misc @@ -0,0 +1 @@ +Properly store the mapping of external ID to Matrix ID for CAS users. diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index e9891e1316..fca210a5a6 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -22,6 +22,7 @@ import attr from twisted.web.client import PartialDownloadError from synapse.api.errors import HttpResponseException +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart @@ -62,6 +63,7 @@ class CasHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname + self._store = hs.get_datastore() self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -72,6 +74,9 @@ class CasHandler: self._http_client = hs.get_proxied_http_client() + # identifier for the external_ids table + self._auth_provider_id = "cas" + self._sso_handler = hs.get_sso_handler() def _build_service_param(self, args: Dict[str, str]) -> str: @@ -267,6 +272,14 @@ class CasHandler: This should be the UI Auth session id. """ + # first check if we're doing a UIA + if session: + return await self._sso_handler.complete_sso_ui_auth_request( + self._auth_provider_id, cas_response.username, session, request, + ) + + # otherwise, we're handling a login request. + # Ensure that the attributes of the logged in user meet the required # attributes. for required_attribute, required_value in self._cas_required_attributes.items(): @@ -293,54 +306,79 @@ class CasHandler: ) return - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) - - # Get the matrix ID from the CAS username. - user_id = await self._map_cas_user_to_matrix_user( - cas_response, user_agent, ip_address - ) + # Call the mapper to register/login the user - if session: - await self._auth_handler.complete_sso_ui_auth( - user_id, session, request, - ) - else: - # If this not a UI auth request than there must be a redirect URL. - assert client_redirect_url + # If this not a UI auth request than there must be a redirect URL. 
+        assert client_redirect_url is not None
 
-        await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url
-        )
+        try:
+            await self._complete_cas_login(cas_response, request, client_redirect_url)
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._sso_handler.render_error(request, "mapping_error", str(e))
 
-    async def _map_cas_user_to_matrix_user(
-        self, cas_response: CasResponse, user_agent: str, ip_address: str,
-    ) -> str:
+    async def _complete_cas_login(
+        self,
+        cas_response: CasResponse,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ) -> None:
         """
-        Given a CAS username, retrieve the user ID for it and possibly register the user.
+        Given a CAS response, complete the login flow
+
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
 
         Args:
             cas_response: The parsed CAS response.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
 
-        Returns:
-            The user ID associated with this response.
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
         """
-
+        # Note that CAS does not support a mapping provider, so the logic is hard-coded.
         localpart = map_username_to_mxid_localpart(cas_response.username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
-        displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
+        async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Map from CAS attributes to user attributes.
+            """
+            # Due to the grandfathering logic matching any previously registered
+            # mxids it isn't expected for there to be any failures.
+            if failures:
+                raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+            display_name = cas_response.attributes.get(
+                self._cas_displayname_attribute, None
+            )
+
+            return UserAttributes(localpart=localpart, display_name=display_name)
 
-        # If the user does not exist, register it.
-        if not registered_user_id:
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=displayname,
-                user_agent_ips=[(user_agent, ip_address)],
+        async def grandfather_existing_users() -> Optional[str]:
+            # Since CAS did not always use the user_external_ids table, always
+            # attempt to map to existing users.
+ user_id = UserID(localpart, self._hostname).to_string() + + logger.debug( + "Looking for existing account based on mapped %s", user_id, ) - return registered_user_id + users = await self._store.get_users_by_id_case_insensitive(user_id) + if users: + registered_user_id = list(users.keys())[0] + logger.info("Grandfathering mapping to %s", registered_user_id) + return registered_user_id + + return None + + await self._sso_handler.complete_sso_login_request( + self._auth_provider_id, + cas_response.username, + request, + client_redirect_url, + cas_response_to_user_attributes, + grandfather_existing_users, + ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index b0a8c8c7d2..33cd6bc178 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -173,7 +173,7 @@ class SsoHandler: request: SynapseRequest, client_redirect_url: str, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], - grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]], + grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], extra_login_attributes: Optional[JsonDict] = None, ) -> None: """ @@ -241,7 +241,7 @@ class SsoHandler: ) # Check for grandfathering of users. - if not user_id and grandfather_existing_users: + if not user_id: user_id = await grandfather_existing_users() if user_id: # Future logins should also match this user ID. diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py new file mode 100644 index 0000000000..bd7a1b6891 --- /dev/null +++ b/tests/handlers/test_cas.py @@ -0,0 +1,121 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.handlers.cas_handler import CasResponse + +from tests.test_utils import simple_async_mock +from tests.unittest import HomeserverTestCase + +# These are a few constants that are used as config parameters in the tests. +BASE_URL = "https://synapse/" +SERVER_URL = "https://issuer/" + + +class CasHandlerTestCase(HomeserverTestCase): + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + cas_config = { + "enabled": True, + "server_url": SERVER_URL, + "service_url": BASE_URL, + } + config["cas_config"] = cas_config + + return config + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + self.handler = hs.get_cas_handler() + + # Reduce the number of attempts when generating MXIDs. 
+ sso_handler = hs.get_sso_handler() + sso_handler._MAP_USERNAME_RETRIES = 3 + + return hs + + def test_map_cas_user_to_user(self): + """Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + def test_map_cas_user_to_existing_user(self): + """Existing users can log in with CAS account.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # Map a user via SSO. + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + # Subsequent calls should map to the same mxid. + auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + def test_map_cas_user_to_invalid_localpart(self): + """CAS automaps invalid characters to base-64 encoding.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + cas_response = CasResponse("föö", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@f=c3=b6=c3=b6:test", request, "redirect_uri", None + ) + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) -- cgit 1.5.1 From 1c9a8505623475ae28067e6f0e8e74ede70c728a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 4 Jan 2021 10:04:50 -0500 Subject: Add type hints to the crypto module. (#8999) --- changelog.d/8999.misc | 1 + mypy.ini | 2 + synapse/crypto/context_factory.py | 2 +- synapse/crypto/event_signing.py | 29 ++-- synapse/crypto/keyring.py | 206 +++++++++++++++++------------ synapse/federation/transport/server.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 9 +- synapse/storage/databases/main/keys.py | 10 +- tests/crypto/test_keyring.py | 10 +- 9 files changed, 158 insertions(+), 113 deletions(-) create mode 100644 changelog.d/8999.misc (limited to 'tests') diff --git a/changelog.d/8999.misc b/changelog.d/8999.misc new file mode 100644 index 0000000000..3987204f06 --- /dev/null +++ b/changelog.d/8999.misc @@ -0,0 +1 @@ +Add type hints to the crypto module. 
diff --git a/mypy.ini b/mypy.ini index a54f34fe24..6a53abfaa9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,7 @@ files = synapse/api, synapse/appservice, synapse/config, + synapse/crypto, synapse/event_auth.py, synapse/events/builder.py, synapse/events/validator.py, @@ -75,6 +76,7 @@ files = synapse/storage/background_updates.py, synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, + synapse/storage/databases/main/keys.py, synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 57fd426e87..74b67b230a 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -227,7 +227,7 @@ class ConnectionVerifier: # This code is based on twisted.internet.ssl.ClientTLSOptions. - def __init__(self, hostname: bytes, verify_certs): + def __init__(self, hostname: bytes, verify_certs: bool): self._verify_certs = verify_certs _decoded = hostname.decode("ascii") diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 0422c43fab..8fb116ae18 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -18,7 +18,7 @@ import collections.abc import hashlib import logging -from typing import Dict +from typing import Any, Callable, Dict, Tuple from canonicaljson import encode_canonical_json from signedjson.sign import sign_json @@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion +from synapse.events import EventBase from synapse.events.utils import prune_event, prune_event_dict from synapse.types import JsonDict logger = logging.getLogger(__name__) +Hasher = Callable[[bytes], "hashlib._Hash"] -def check_event_content_hash(event, hash_algorithm=hashlib.sha256): + +def check_event_content_hash( + event: EventBase, hash_algorithm: Hasher = hashlib.sha256 +) -> bool: """Check whether the hash for this PDU matches the contents""" name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm) logger.debug( @@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): return message_hash_bytes == expected_hash -def compute_content_hash(event_dict, hash_algorithm): +def compute_content_hash( + event_dict: Dict[str, Any], hash_algorithm: Hasher +) -> Tuple[str, bytes]: """Compute the content hash of an event, which is the hash of the unredacted event. Args: - event_dict (dict): The unredacted event as a dict + event_dict: The unredacted event as a dict hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use to hash the event Returns: - tuple[str, bytes]: A tuple of the name of hash and the hash as raw - bytes. + A tuple of the name of hash and the hash as raw bytes. """ event_dict = dict(event_dict) event_dict.pop("age_ts", None) @@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm): return hashed.name, hashed.digest() -def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): +def compute_event_reference_hash( + event, hash_algorithm: Hasher = hashlib.sha256 +) -> Tuple[str, bytes]: """Computes the event reference hash. This is the hash of the redacted event. Args: - event (FrozenEvent) + event hash_algorithm: A hasher from `hashlib`, e.g. 
hashlib.sha256, to use to hash the event Returns: - tuple[str, bytes]: A tuple of the name of hash and the hash as raw - bytes. + A tuple of the name of hash and the hash as raw bytes. """ tmp_event = prune_event(event) event_dict = tmp_event.get_pdu_json() @@ -156,7 +163,7 @@ def add_hashes_and_signatures( event_dict: JsonDict, signature_name: str, signing_key: SigningKey, -): +) -> None: """Add content hash and sign the event Args: diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index f23eacc0d7..902128a23c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging import urllib from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr from signedjson.key import ( @@ -40,6 +42,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.config.key import TrustedKeyServer from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, @@ -47,11 +50,15 @@ from synapse.logging.context import ( run_in_background, ) from synapse.storage.keys import FetchKeyResult +from synapse.types import JsonDict from synapse.util import unwrapFirstError from synapse.util.async_helpers import yieldable_gather_results from synapse.util.metrics import Measure from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -61,16 +68,17 @@ class VerifyJsonRequest: A request to verify a JSON object. Attributes: - server_name(str): The name of the server to verify against. - - key_ids(set[str]): The set of key_ids to that could be used to verify the - JSON object + server_name: The name of the server to verify against. - json_object(dict): The JSON object to verify. + json_object: The JSON object to verify. - minimum_valid_until_ts (int): time at which we require the signing key to + minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) + request_name: The name of the request. + + key_ids: The set of key_ids to that could be used to verify the JSON object + key_ready (Deferred[str, str, nacl.signing.VerifyKey]): A deferred (server_name, key_id, verify_key) tuple that resolves when a verify key has been fetched. The deferreds' callbacks are run with no @@ -80,12 +88,12 @@ class VerifyJsonRequest: errbacks with an M_UNAUTHORIZED SynapseError. """ - server_name = attr.ib() - json_object = attr.ib() - minimum_valid_until_ts = attr.ib() - request_name = attr.ib() - key_ids = attr.ib(init=False) - key_ready = attr.ib(default=attr.Factory(defer.Deferred)) + server_name = attr.ib(type=str) + json_object = attr.ib(type=JsonDict) + minimum_valid_until_ts = attr.ib(type=int) + request_name = attr.ib(type=str) + key_ids = attr.ib(init=False, type=List[str]) + key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) def __attrs_post_init__(self): self.key_ids = signature_ids(self.json_object, self.server_name) @@ -96,7 +104,9 @@ class KeyLookupError(ValueError): class Keyring: - def __init__(self, hs, key_fetchers=None): + def __init__( + self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None + ): self.clock = hs.get_clock() if key_fetchers is None: @@ -112,22 +122,26 @@ class Keyring: # completes. 
# # These are regular, logcontext-agnostic Deferreds. - self.key_downloads = {} + self.key_downloads = {} # type: Dict[str, defer.Deferred] def verify_json_for_server( - self, server_name, json_object, validity_time, request_name - ): + self, + server_name: str, + json_object: JsonDict, + validity_time: int, + request_name: str, + ) -> defer.Deferred: """Verify that a JSON object has been signed by a given server Args: - server_name (str): name of the server which must have signed this object + server_name: name of the server which must have signed this object - json_object (dict): object to be checked + json_object: object to be checked - validity_time (int): timestamp at which we require the signing key to + validity_time: timestamp at which we require the signing key to be valid. (0 implies we don't care) - request_name (str): an identifier for this json object (eg, an event id) + request_name: an identifier for this json object (eg, an event id) for logging. Returns: @@ -138,12 +152,14 @@ class Keyring: requests = (req,) return make_deferred_yieldable(self._verify_objects(requests)[0]) - def verify_json_objects_for_server(self, server_and_json): + def verify_json_objects_for_server( + self, server_and_json: Iterable[Tuple[str, dict, int, str]] + ) -> List[defer.Deferred]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. Args: - server_and_json (iterable[Tuple[str, dict, int, str]): + server_and_json: Iterable of (server_name, json_object, validity_time, request_name) tuples. @@ -164,13 +180,14 @@ class Keyring: for server_name, json_object, validity_time, request_name in server_and_json ) - def _verify_objects(self, verify_requests): + def _verify_objects( + self, verify_requests: Iterable[VerifyJsonRequest] + ) -> List[defer.Deferred]: """Does the work of verify_json_[objects_]for_server Args: - verify_requests (iterable[VerifyJsonRequest]): - Iterable of verification requests. + verify_requests: Iterable of verification requests. Returns: List: for each input item, a deferred indicating success @@ -182,7 +199,7 @@ class Keyring: key_lookups = [] handle = preserve_fn(_handle_key_deferred) - def process(verify_request): + def process(verify_request: VerifyJsonRequest) -> defer.Deferred: """Process an entry in the request list Adds a key request to key_lookups, and returns a deferred which @@ -222,18 +239,20 @@ class Keyring: return results - async def _start_key_lookups(self, verify_requests): + async def _start_key_lookups( + self, verify_requests: List[VerifyJsonRequest] + ) -> None: """Sets off the key fetches for each verify request Once each fetch completes, verify_request.key_ready will be resolved. Args: - verify_requests (List[VerifyJsonRequest]): + verify_requests: """ try: # map from server name to a set of outstanding request ids - server_to_request_ids = {} + server_to_request_ids = {} # type: Dict[str, Set[int]] for verify_request in verify_requests: server_name = verify_request.server_name @@ -275,11 +294,11 @@ class Keyring: except Exception: logger.exception("Error starting key lookups") - async def wait_for_previous_lookups(self, server_names) -> None: + async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None: """Waits for any previous key lookups for the given servers to finish. 
Args: - server_names (Iterable[str]): list of servers which we want to look up + server_names: list of servers which we want to look up Returns: Resolves once all key lookups for the given servers have @@ -304,7 +323,7 @@ class Keyring: loop_count += 1 - def _get_server_verify_keys(self, verify_requests): + def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None: """Tries to find at least one key for each verify request For each verify_request, verify_request.key_ready is called back with @@ -312,7 +331,7 @@ class Keyring: with a SynapseError if none of the keys are found. Args: - verify_requests (list[VerifyJsonRequest]): list of verify requests + verify_requests: list of verify requests """ remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} @@ -366,17 +385,19 @@ class Keyring: run_in_background(do_iterations) - async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): + async def _attempt_key_fetches_with_fetcher( + self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest] + ): """Use a key fetcher to attempt to satisfy some key requests Args: - fetcher (KeyFetcher): fetcher to use to fetch the keys - remaining_requests (set[VerifyJsonRequest]): outstanding key requests. + fetcher: fetcher to use to fetch the keys + remaining_requests: outstanding key requests. Any successfully-completed requests will be removed from the list. """ - # dict[str, dict[str, int]]: keys to fetch. + # The keys to fetch. # server_name -> key_id -> min_valid_ts - missing_keys = defaultdict(dict) + missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]] for verify_request in remaining_requests: # any completed requests should already have been removed @@ -438,16 +459,18 @@ class Keyring: remaining_requests.difference_update(completed) -class KeyFetcher: - async def get_keys(self, keys_to_fetch): +class KeyFetcher(metaclass=abc.ABCMeta): + @abc.abstractmethod + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, dict[str, int]]): + keys_to_fetch: the keys to be fetched. 
server_name -> key_id -> min_valid_ts Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: - map from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult """ raise NotImplementedError @@ -455,31 +478,35 @@ class KeyFetcher: class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """see KeyFetcher.get_keys""" - keys_to_fetch = ( + key_ids_to_fetch = ( (server_name, key_id) for server_name, keys_for_server in keys_to_fetch.items() for key_id in keys_for_server.keys() ) - res = await self.store.get_server_verify_keys(keys_to_fetch) - keys = {} + res = await self.store.get_server_verify_keys(key_ids_to_fetch) + keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key return keys -class BaseV2KeyFetcher: - def __init__(self, hs): +class BaseV2KeyFetcher(KeyFetcher): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.config = hs.get_config() - async def process_v2_response(self, from_server, response_json, time_added_ms): + async def process_v2_response( + self, from_server: str, response_json: JsonDict, time_added_ms: int + ) -> Dict[str, FetchKeyResult]: """Parse a 'Server Keys' structure from the result of a /key request This is used to parse either the entirety of the response from @@ -493,16 +520,16 @@ class BaseV2KeyFetcher: to /_matrix/key/v2/query. Args: - from_server (str): the name of the server producing this result: either + from_server: the name of the server producing this result: either the origin server for a /_matrix/key/v2/server request, or the notary for a /_matrix/key/v2/query. 
- response_json (dict): the json-decoded Server Keys response object + response_json: the json-decoded Server Keys response object - time_added_ms (int): the timestamp to record in server_keys_json + time_added_ms: the timestamp to record in server_keys_json Returns: - Deferred[dict[str, FetchKeyResult]]: map from key_id to result object + Map from key_id to result object """ ts_valid_until_ms = response_json["valid_until_ts"] @@ -575,21 +602,22 @@ class BaseV2KeyFetcher: class PerspectivesKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the "perspectives" servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """see KeyFetcher.get_keys""" - async def get_key(key_server): + async def get_key(key_server: TrustedKeyServer) -> Dict: try: - result = await self.get_server_verify_key_v2_indirect( + return await self.get_server_verify_key_v2_indirect( keys_to_fetch, key_server ) - return result except KeyLookupError as e: logger.warning( "Key lookup failed from %r: %s", key_server.server_name, e @@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ).addErrback(unwrapFirstError) ) - union_of_keys = {} + union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] for result in results: for server_name, keys in result.items(): union_of_keys.setdefault(server_name, {}).update(keys) return union_of_keys - async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): + async def get_server_verify_key_v2_indirect( + self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, dict[str, int]]): + keys_to_fetch: the keys to be fetched. 
server_name -> key_id -> min_valid_ts - key_server (synapse.config.key.TrustedKeyServer): notary server to query for - the keys + key_server: notary server to query for the keys Returns: - dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map - from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult Raises: KeyLookupError if there was an error processing the entire response from @@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) - keys = {} - added_keys = [] + keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] + added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]] time_now_ms = self.clock.time_msec() + assert isinstance(query_response, dict) for response in query_response["server_keys"]: # do this first, so that we can give useful errors thereafter server_name = response.get("server_name") @@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return keys - def _validate_perspectives_response(self, key_server, response): + def _validate_perspectives_response( + self, key_server: TrustedKeyServer, response: JsonDict + ) -> None: """Optionally check the signature on the result of a /key/query request Args: - key_server (synapse.config.key.TrustedKeyServer): the notary server that - produced this result + key_server: the notary server that produced this result - response (dict): the json-decoded Server Keys response object + response: the json-decoded Server Keys response object """ perspective_name = key_server.server_name perspective_keys = key_server.verify_keys @@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): class ServerKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the origin servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_federation_http_client() - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, iterable[str]]): + keys_to_fetch: the keys to be fetched. server_name -> key_ids Returns: - dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]: - map from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult """ results = {} - async def get_key(key_to_fetch_item): + async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None: server_name, key_ids = key_to_fetch_item try: keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) @@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher): await yieldable_gather_results(get_key, keys_to_fetch.items()) return results - async def get_server_verify_key_v2_direct(self, server_name, key_ids): + async def get_server_verify_key_v2_direct( + self, server_name: str, key_ids: Iterable[str] + ) -> Dict[str, FetchKeyResult]: """ Args: - server_name (str): - key_ids (iterable[str]): + server_name: + key_ids: Returns: - dict[str, FetchKeyResult]: map from key ID to lookup result + Map from key ID to lookup result Raises: KeyLookupError if there was a problem making the lookup """ - keys = {} # type: dict[str, FetchKeyResult] + keys = {} # type: Dict[str, FetchKeyResult] for requested_key_id in key_ids: # we may have found this key as a side-effect of asking for another. 
@@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) + assert isinstance(response, dict) if response["server_name"] != server_name: raise KeyLookupError( "Expected a response for server %r not %r" @@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher): return keys -async def _handle_key_deferred(verify_request) -> None: +async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: """Waits for the key to become available, and then performs a verification Args: - verify_request (VerifyJsonRequest): + verify_request: Raises: SynapseError if there was a problem performing the verification diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 434718ddfc..cfd094e58f 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -144,7 +144,7 @@ class Authenticator: ): raise FederationDeniedError(origin) - if not json_request["signatures"]: + if origin is None or not json_request["signatures"]: raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index f843f02454..c57ac22e58 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, Set +from typing import Dict from signedjson.sign import sign_json @@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource): time_now_ms = self.clock.time_msec() - cache_misses = {} # type: Dict[str, Set[str]] + # Note that the value is unused. + cache_misses = {} # type: Dict[str, Dict[str, int]] for (server_name, key_id, from_server), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 continue if key_id is not None: @@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource): ) if miss: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 # Cast to bytes since postgresql returns a memoryview. 
json_results.add(bytes(most_recent_result["key_json"])) else: diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index f8f4bb9b3f..04ac2d0ced 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore from synapse.storage.keys import FetchKeyResult +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore): ) async def get_server_verify_keys( self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: + ) -> Dict[Tuple[str, str], FetchKeyResult]: """ Args: server_name_and_key_ids: @@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore): """ keys = {} - def _get_keys(txn, batch): + def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None: """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) @@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore): # `ts_valid_until_ms`. ts_valid_until_ms = 0 - res = FetchKeyResult( + keys[(server_name, key_id)] = FetchKeyResult( verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), valid_until_ts=ts_valid_until_ms, ) - keys[(server_name, key_id)] = res - def _txn(txn): + def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: for batch in batch_iter(server_name_and_key_ids, 50): _get_keys(txn, batch) return keys diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d146f2254f..1d65ea2f9c 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock() kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Tests that we correctly handle key requests for keys we've stored with a null `ts_valid_until_ms` """ - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) kr = keyring.Keyring( @@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher1 = keyring.KeyFetcher() + mock_fetcher1 = Mock() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) - mock_fetcher2 = keyring.KeyFetcher() + mock_fetcher2 = Mock() mock_fetcher2.get_keys = Mock(side_effect=get_keys2) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) -- cgit 1.5.1 From d2c616a41381c9e2d43b08d5f225b52042d94d23 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 4 Jan 2021 18:13:49 +0000 Subject: Combine the SSO Redirect Servlets (#9015) * Implement CasHandler.handle_redirect_request ... to make it match OidcHandler and SamlHandler * Clean up interface for OidcHandler.handle_redirect_request Make it accept `client_redirect_url=None`. 
* Clean up interface for `SamlHandler.handle_redirect_request` ... bring it into line with CAS and OIDC by making it take a Request parameter, move the magic for `client_redirect_url` for UIA into the handler, and fix the return type to be a `str` rather than a `bytes`. * Define a common protocol for SSO auth provider impls * Give SsoIdentityProvider an ID and register them * Combine the SSO Redirect servlets Now that the SsoHandler knows about the identity providers, we can combine the various *RedirectServlets into a single implementation which delegates to the right IdP. * changelog --- changelog.d/9015.feature | 1 + synapse/handlers/cas_handler.py | 35 ++++++++++---- synapse/handlers/oidc_handler.py | 15 ++++-- synapse/handlers/saml_handler.py | 25 +++++++--- synapse/handlers/sso.py | 86 +++++++++++++++++++++++++++++++++- synapse/rest/client/v1/login.py | 89 ++++++++---------------------------- synapse/rest/client/v2_alpha/auth.py | 34 ++++++-------- tests/rest/client/v1/test_login.py | 2 +- 8 files changed, 174 insertions(+), 113 deletions(-) create mode 100644 changelog.d/9015.feature (limited to 'tests') diff --git a/changelog.d/9015.feature b/changelog.d/9015.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9015.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index fca210a5a6..295974c521 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -75,10 +75,12 @@ class CasHandler: self._http_client = hs.get_proxied_http_client() # identifier for the external_ids table - self._auth_provider_id = "cas" + self.idp_id = "cas" self._sso_handler = hs.get_sso_handler() + self._sso_handler.register_identity_provider(self) + def _build_service_param(self, args: Dict[str, str]) -> str: """ Generates a value to use as the "service" parameter when redirecting or @@ -105,7 +107,7 @@ class CasHandler: Args: ticket: The CAS ticket from the client. service_args: Additional arguments to include in the service URL. - Should be the same as those passed to `get_redirect_url`. + Should be the same as those passed to `handle_redirect_request`. Raises: CasError: If there's an error parsing the CAS response. @@ -184,16 +186,31 @@ class CasHandler: return CasResponse(user, attributes) - def get_redirect_url(self, service_args: Dict[str, str]) -> str: - """ - Generates a URL for the CAS server where the client should be redirected. + async def handle_redirect_request( + self, + request: SynapseRequest, + client_redirect_url: Optional[bytes], + ui_auth_session_id: Optional[str] = None, + ) -> str: + """Generates a URL for the CAS server where the client should be redirected. Args: - service_args: Additional arguments to include in the final redirect URL. + request: the incoming HTTP request + client_redirect_url: the URL that we should redirect the + client to after login (or None for UI Auth). + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). Returns: - The URL to redirect the client to. 
+ URL to redirect to """ + + if ui_auth_session_id: + service_args = {"session": ui_auth_session_id} + else: + assert client_redirect_url + service_args = {"redirectUrl": client_redirect_url.decode("utf8")} + args = urllib.parse.urlencode( {"service": self._build_service_param(service_args)} ) @@ -275,7 +292,7 @@ class CasHandler: # first check if we're doing a UIA if session: return await self._sso_handler.complete_sso_ui_auth_request( - self._auth_provider_id, cas_response.username, session, request, + self.idp_id, cas_response.username, session, request, ) # otherwise, we're handling a login request. @@ -375,7 +392,7 @@ class CasHandler: return None await self._sso_handler.complete_sso_login_request( - self._auth_provider_id, + self.idp_id, cas_response.username, request, client_redirect_url, diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 709f8dfc13..3e2b60eb7b 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -119,10 +119,12 @@ class OidcHandler(BaseHandler): self._macaroon_secret_key = hs.config.macaroon_secret_key # identifier for the external_ids table - self._auth_provider_id = "oidc" + self.idp_id = "oidc" self._sso_handler = hs.get_sso_handler() + self._sso_handler.register_identity_provider(self) + def _validate_metadata(self): """Verifies the provider metadata. @@ -475,7 +477,7 @@ class OidcHandler(BaseHandler): async def handle_redirect_request( self, request: SynapseRequest, - client_redirect_url: bytes, + client_redirect_url: Optional[bytes], ui_auth_session_id: Optional[str] = None, ) -> str: """Handle an incoming request to /login/sso/redirect @@ -499,7 +501,7 @@ class OidcHandler(BaseHandler): request: the incoming request from the browser. We'll respond to it with a redirect and a cookie. client_redirect_url: the URL that we should redirect the client to - when everything is done + when everything is done (or None for UI Auth) ui_auth_session_id: The session ID of the ongoing UI Auth (or None if this is a login). 
@@ -511,6 +513,9 @@ class OidcHandler(BaseHandler): state = generate_token() nonce = generate_token() + if not client_redirect_url: + client_redirect_url = b"" + cookie = self._generate_oidc_session_token( state=state, nonce=nonce, @@ -682,7 +687,7 @@ class OidcHandler(BaseHandler): return return await self._sso_handler.complete_sso_ui_auth_request( - self._auth_provider_id, remote_user_id, ui_auth_session_id, request + self.idp_id, remote_user_id, ui_auth_session_id, request ) # otherwise, it's a login @@ -923,7 +928,7 @@ class OidcHandler(BaseHandler): extra_attributes = await get_extra_attributes(userinfo, token) await self._sso_handler.complete_sso_login_request( - self._auth_provider_id, + self.idp_id, remote_user_id, request, client_redirect_url, diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 5fa7ab3f8b..6106237f1f 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -73,27 +73,38 @@ class SamlHandler(BaseHandler): ) # identifier for the external_ids table - self._auth_provider_id = "saml" + self.idp_id = "saml" # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._sso_handler = hs.get_sso_handler() + self._sso_handler.register_identity_provider(self) - def handle_redirect_request( - self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None - ) -> bytes: + async def handle_redirect_request( + self, + request: SynapseRequest, + client_redirect_url: Optional[bytes], + ui_auth_session_id: Optional[str] = None, + ) -> str: """Handle an incoming request to /login/sso/redirect Args: + request: the incoming HTTP request client_redirect_url: the URL that we should redirect the - client to when everything is done + client to after login (or None for UI Auth). ui_auth_session_id: The session ID of the ongoing UI Auth (or None if this is a login). Returns: URL to redirect to """ + if not client_redirect_url: + # Some SAML identity providers (e.g. Google) require a + # RelayState parameter on requests, so pass in a dummy redirect URL + # (which will never get used). + client_redirect_url = b"unused" + reqid, info = self._saml_client.prepare_for_authenticate( entityid=self._saml_idp_entityid, relay_state=client_redirect_url ) @@ -210,7 +221,7 @@ class SamlHandler(BaseHandler): return return await self._sso_handler.complete_sso_ui_auth_request( - self._auth_provider_id, + self.idp_id, remote_user_id, current_session.ui_auth_session_id, request, @@ -306,7 +317,7 @@ class SamlHandler(BaseHandler): return None await self._sso_handler.complete_sso_login_request( - self._auth_provider_id, + self.idp_id, remote_user_id, request, client_redirect_url, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 33cd6bc178..d8fb8cdd05 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -12,15 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import abc import logging from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional import attr -from typing_extensions import NoReturn +from typing_extensions import NoReturn, Protocol from twisted.web.http import Request -from synapse.api.errors import RedirectException, SynapseError +from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters @@ -40,6 +41,53 @@ class MappingException(Exception): """ +class SsoIdentityProvider(Protocol): + """Abstract base class to be implemented by SSO Identity Providers + + An Identity Provider, or IdP, is an external HTTP service which authenticates a user + to say whether they should be allowed to log in, or perform a given action. + + Synapse supports various implementations of IdPs, including OpenID Connect, SAML, + and CAS. + + The main entry point is `handle_redirect_request`, which should return a URI to + redirect the user's browser to the IdP's authentication page. + + Each IdP should be registered with the SsoHandler via + `hs.get_sso_handler().register_identity_provider()`, so that requests to + `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched. + """ + + @property + @abc.abstractmethod + def idp_id(self) -> str: + """A unique identifier for this SSO provider + + Eg, "saml", "cas", "github" + """ + + @abc.abstractmethod + async def handle_redirect_request( + self, + request: SynapseRequest, + client_redirect_url: Optional[bytes], + ui_auth_session_id: Optional[str] = None, + ) -> str: + """Handle an incoming request to /login/sso/redirect + + Args: + request: the incoming HTTP request + client_redirect_url: the URL that we should redirect the + client to after login (or None for UI Auth). + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). + + Returns: + URL to redirect to + """ + raise NotImplementedError() + + @attr.s class UserAttributes: # the localpart of the mxid that the mapper has assigned to the user. @@ -100,6 +148,14 @@ class SsoHandler: # a map from session id to session data self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] + # map from idp_id to SsoIdentityProvider + self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + + def register_identity_provider(self, p: SsoIdentityProvider): + p_id = p.idp_id + assert p_id not in self._identity_providers + self._identity_providers[p_id] = p + def render_error( self, request: Request, @@ -124,6 +180,32 @@ class SsoHandler: ) respond_with_html(request, code, html) + async def handle_redirect_request( + self, request: SynapseRequest, client_redirect_url: bytes, + ) -> str: + """Handle a request to /login/sso/redirect + + Args: + request: incoming HTTP request + client_redirect_url: the URL that we should redirect the + client to after login. 
+ + Returns: + the URI to redirect to + """ + if not self._identity_providers: + raise SynapseError( + 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED + ) + + # if we only have one auth provider, redirect to it directly + if len(self._identity_providers) == 1: + ap = next(iter(self._identity_providers.values())) + return await ap.handle_redirect_request(request, client_redirect_url) + + # otherwise, we have a configuration error + raise Exception("Multiple SSO identity providers have been configured!") + async def get_sso_user_by_remote_user_id( self, auth_provider_id: str, remote_user_id: str ) -> Optional[str]: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 5f4c6703db..ebc346105b 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet): return result -class BaseSSORedirectServlet(RestServlet): - """Common base class for /login/sso/redirect impls""" - +class SsoRedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + def __init__(self, hs: "HomeServer"): + # make sure that the relevant handlers are instantiated, so that they + # register themselves with the main SSOHandler. + if hs.config.cas_enabled: + hs.get_cas_handler() + elif hs.config.saml2_enabled: + hs.get_saml_handler() + elif hs.config.oidc_enabled: + hs.get_oidc_handler() + self._sso_handler = hs.get_sso_handler() + async def on_GET(self, request: SynapseRequest): - args = request.args - if b"redirectUrl" not in args: - return 400, "Redirect URL not specified for SSO auth" - client_redirect_url = args[b"redirectUrl"][0] - sso_url = await self.get_sso_url(request, client_redirect_url) + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding=None + ) + sso_url = await self._sso_handler.handle_redirect_request( + request, client_redirect_url + ) + logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) finish_request(request) - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - """Get the URL to redirect to, to perform SSO auth - - Args: - request: The client request to redirect. 
- client_redirect_url: the URL that we should redirect the - client to when everything is done - - Returns: - URL to redirect to - """ - # to be implemented by subclasses - raise NotImplementedError() - - -class CasRedirectServlet(BaseSSORedirectServlet): - def __init__(self, hs): - self._cas_handler = hs.get_cas_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return self._cas_handler.get_redirect_url( - {"redirectUrl": client_redirect_url} - ).encode("ascii") - class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) @@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet): ) -class SAMLRedirectServlet(BaseSSORedirectServlet): - PATTERNS = client_patterns("/login/sso/redirect", v1=True) - - def __init__(self, hs): - self._saml_handler = hs.get_saml_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return self._saml_handler.handle_redirect_request(client_redirect_url) - - -class OIDCRedirectServlet(BaseSSORedirectServlet): - """Implementation for /login/sso/redirect for the OIDC login flow.""" - - PATTERNS = client_patterns("/login/sso/redirect", v1=True) - - def __init__(self, hs): - self._oidc_handler = hs.get_oidc_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return await self._oidc_handler.handle_redirect_request( - request, client_redirect_url - ) - - def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: - CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - elif hs.config.saml2_enabled: - SAMLRedirectServlet(hs).register(http_server) - elif hs.config.oidc_enabled: - OIDCRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index fab077747f..9b9514632f 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -14,15 +14,20 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX +from synapse.handlers.sso import SsoIdentityProvider from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet, parse_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -35,7 +40,7 @@ class AuthRestServlet(RestServlet): PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -85,31 +90,20 @@ class AuthRestServlet(RestServlet): elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to # re-authenticate with their SSO provider. - if self._cas_enabled: - # Generate a request to CAS that redirects back to an endpoint - # to verify the successful authentication. - sso_redirect_url = self._cas_handler.get_redirect_url( - {"session": session}, - ) + if self._cas_enabled: + sso_auth_provider = self._cas_handler # type: SsoIdentityProvider elif self._saml_enabled: - # Some SAML identity providers (e.g. Google) require a - # RelayState parameter on requests. 
It is not necessary here, so - # pass in a dummy redirect URL (which will never get used). - client_redirect_url = b"unused" - sso_redirect_url = self._saml_handler.handle_redirect_request( - client_redirect_url, session - ) - + sso_auth_provider = self._saml_handler elif self._oidc_enabled: - client_redirect_url = b"" - sso_redirect_url = await self._oidc_handler.handle_redirect_request( - request, client_redirect_url, session - ) - + sso_auth_provider = self._oidc_handler else: raise SynapseError(400, "Homeserver not configured for SSO.") + sso_redirect_url = await sso_auth_provider.handle_redirect_request( + request, None, session + ) + html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) else: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 18932d7518..999d628315 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -385,7 +385,7 @@ class CASTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", cas_ticket_url) # Test that the response is HTML. - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, 200, channel.result) content_type_header_value = "" for header in channel.result.get("headers", []): if header[0] == b"Content-Type": -- cgit 1.5.1 From 9dde9c9f01ff8ed4c60314f10d97261739ea0547 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 5 Jan 2021 07:41:48 -0500 Subject: Implement MSC2176: Updated redaction rules (#8984) An experimental room version ("org.matrix.msc2176") contains the new redaction rules for testing. --- changelog.d/8984.feature | 1 + synapse/api/room_versions.py | 32 ++++++-- synapse/events/utils.py | 16 +++- synapse/handlers/room.py | 2 +- tests/events/test_utils.py | 185 ++++++++++++++++++++++++++++++++++++++----- 5 files changed, 206 insertions(+), 30 deletions(-) create mode 100644 changelog.d/8984.feature (limited to 'tests') diff --git a/changelog.d/8984.feature b/changelog.d/8984.feature new file mode 100644 index 0000000000..4db629746e --- /dev/null +++ b/changelog.d/8984.feature @@ -0,0 +1 @@ +Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version. diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index f3ecbf36b6..de2cc15d33 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -51,11 +51,11 @@ class RoomDisposition: class RoomVersion: """An object which describes the unique attributes of a room version.""" - identifier = attr.ib() # str; the identifier for this version - disposition = attr.ib() # str; one of the RoomDispositions - event_format = attr.ib() # int; one of the EventFormatVersions - state_res = attr.ib() # int; one of the StateResolutionVersions - enforce_key_validity = attr.ib() # bool + identifier = attr.ib(type=str) # the identifier for this version + disposition = attr.ib(type=str) # one of the RoomDispositions + event_format = attr.ib(type=int) # one of the EventFormatVersions + state_res = attr.ib(type=int) # one of the StateResolutionVersions + enforce_key_validity = attr.ib(type=bool) # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules special_case_aliases_auth = attr.ib(type=bool) @@ -64,9 +64,11 @@ class RoomVersion: # * Floats # * NaN, Infinity, -Infinity strict_canonicaljson = attr.ib(type=bool) - # bool: MSC2209: Check 'notifications' key while verifying + # MSC2209: Check 'notifications' key while verifying # m.room.power_levels auth rules. 
limit_notifications_power_levels = attr.ib(type=bool) + # MSC2174/MSC2176: Apply updated redaction rules algorithm. + msc2176_redaction_rules = attr.ib(type=bool) class RoomVersions: @@ -79,6 +81,7 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, + msc2176_redaction_rules=False, ) V2 = RoomVersion( "2", @@ -89,6 +92,7 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, + msc2176_redaction_rules=False, ) V3 = RoomVersion( "3", @@ -99,6 +103,7 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, + msc2176_redaction_rules=False, ) V4 = RoomVersion( "4", @@ -109,6 +114,7 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, + msc2176_redaction_rules=False, ) V5 = RoomVersion( "5", @@ -119,6 +125,7 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, + msc2176_redaction_rules=False, ) V6 = RoomVersion( "6", @@ -129,6 +136,18 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + ) + MSC2176 = RoomVersion( + "org.matrix.msc2176", + RoomDisposition.UNSTABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=True, ) @@ -141,5 +160,6 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V4, RoomVersions.V5, RoomVersions.V6, + RoomVersions.MSC2176, ) } # type: Dict[str, RoomVersion] diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 14f7f1156f..9c22e33813 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -79,13 +79,15 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: "state_key", "depth", "prev_events", - "prev_state", "auth_events", "origin", "origin_server_ts", - "membership", ] + # Room versions from before MSC2176 had additional allowed keys. + if not room_version.msc2176_redaction_rules: + allowed_keys.extend(["prev_state", "membership"]) + event_type = event_dict["type"] new_content = {} @@ -98,6 +100,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: if event_type == EventTypes.Member: add_fields("membership") elif event_type == EventTypes.Create: + # MSC2176 rules state that create events cannot be redacted. 
+ if room_version.msc2176_redaction_rules: + return event_dict + add_fields("creator") elif event_type == EventTypes.JoinRules: add_fields("join_rule") @@ -112,10 +118,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: "kick", "redact", ) + + if room_version.msc2176_redaction_rules: + add_fields("invite") + elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth: add_fields("aliases") elif event_type == EventTypes.RoomHistoryVisibility: add_fields("history_visibility") + elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules: + add_fields("redacts") allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys} diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 1f809fa161..3bece6d668 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -365,7 +365,7 @@ class RoomCreationHandler(BaseHandler): creation_content = { "room_version": new_room_version.identifier, "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, - } + } # type: JsonDict # Check if old room was non-federatable diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index c1274c14af..8ba36c6074 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -34,11 +34,17 @@ def MockEvent(**kwargs): class PruneEventTestCase(unittest.TestCase): - """ Asserts that a new event constructed with `evdict` will look like - `matchdict` when it is redacted. """ - def run_test(self, evdict, matchdict, **kwargs): - self.assertEquals( + """ + Asserts that a new event constructed with `evdict` will look like + `matchdict` when it is redacted. + + Args: + evdict: The dictionary to build the event from. + matchdict: The expected resulting dictionary. + kwargs: Additional keyword arguments used to create the event. + """ + self.assertEqual( prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict ) @@ -55,54 +61,80 @@ class PruneEventTestCase(unittest.TestCase): ) def test_basic_keys(self): + """Ensure that the keys that should be untouched are kept.""" + # Note that some of the values below don't really make sense, but the + # pruning of events doesn't worry about the values of any fields (with + # the exception of the content field). self.run_test( { + "event_id": "$3:domain", "type": "A", "room_id": "!1:domain", "sender": "@2:domain", - "event_id": "$3:domain", + "state_key": "B", + "content": {"other_key": "foo"}, + "hashes": "hashes", + "signatures": {"domain": {"algo:1": "sigs"}}, + "depth": 4, + "prev_events": "prev_events", + "prev_state": "prev_state", + "auth_events": "auth_events", "origin": "domain", + "origin_server_ts": 1234, + "membership": "join", + # Also include a key that should be removed. + "other_key": "foo", }, { + "event_id": "$3:domain", "type": "A", "room_id": "!1:domain", "sender": "@2:domain", - "event_id": "$3:domain", + "state_key": "B", + "hashes": "hashes", + "depth": 4, + "prev_events": "prev_events", + "prev_state": "prev_state", + "auth_events": "auth_events", "origin": "domain", + "origin_server_ts": 1234, + "membership": "join", "content": {}, - "signatures": {}, + "signatures": {"domain": {"algo:1": "sigs"}}, "unsigned": {}, }, ) - def test_unsigned_age_ts(self): + # As of MSC2176 we now redact the membership and prev_state keys.
self.run_test( - {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}}, - { - "type": "B", - "event_id": "$test:domain", - "content": {}, - "signatures": {}, - "unsigned": {"age_ts": 20}, - }, + {"type": "A", "prev_state": "prev_state", "membership": "join"}, + {"type": "A", "content": {}, "signatures": {}, "unsigned": {}}, + room_version=RoomVersions.MSC2176, ) + def test_unsigned(self): + """Ensure that unsigned properties get stripped (except age_ts and replaces_state).""" self.run_test( { "type": "B", "event_id": "$test:domain", - "unsigned": {"other_key": "here"}, + "unsigned": { + "age_ts": 20, + "replaces_state": "$test2:domain", + "other_key": "foo", + }, }, { "type": "B", "event_id": "$test:domain", "content": {}, "signatures": {}, - "unsigned": {}, + "unsigned": {"age_ts": 20, "replaces_state": "$test2:domain"}, }, ) def test_content(self): + """The content dictionary should be stripped in most cases.""" self.run_test( {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}}, { @@ -114,11 +146,35 @@ class PruneEventTestCase(unittest.TestCase): }, ) + # Some events keep a single content key/value. + EVENT_KEEP_CONTENT_KEYS = [ + ("member", "membership", "join"), + ("join_rules", "join_rule", "invite"), + ("history_visibility", "history_visibility", "shared"), + ] + for event_type, key, value in EVENT_KEEP_CONTENT_KEYS: + self.run_test( + { + "type": "m.room." + event_type, + "event_id": "$test:domain", + "content": {key: value, "other_key": "foo"}, + }, + { + "type": "m.room." + event_type, + "event_id": "$test:domain", + "content": {key: value}, + "signatures": {}, + "unsigned": {}, + }, + ) + + def test_create(self): + """Create events are partially redacted until MSC2176.""" self.run_test( { "type": "m.room.create", "event_id": "$test:domain", - "content": {"creator": "@2:domain", "other_field": "here"}, + "content": {"creator": "@2:domain", "other_key": "foo"}, }, { "type": "m.room.create", @@ -129,6 +185,68 @@ class PruneEventTestCase(unittest.TestCase): }, ) + # After MSC2176, create events get nothing redacted. + self.run_test( + {"type": "m.room.create", "content": {"not_a_real_key": True}}, + { + "type": "m.room.create", + "content": {"not_a_real_key": True}, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.MSC2176, + ) + + def test_power_levels(self): + """Power level events keep a variety of content keys.""" + self.run_test( + { + "type": "m.room.power_levels", + "event_id": "$test:domain", + "content": { + "ban": 1, + "events": {"m.room.name": 100}, + "events_default": 2, + "invite": 3, + "kick": 4, + "redact": 5, + "state_default": 6, + "users": {"@admin:domain": 100}, + "users_default": 7, + "other_key": 8, + }, + }, + { + "type": "m.room.power_levels", + "event_id": "$test:domain", + "content": { + "ban": 1, + "events": {"m.room.name": 100}, + "events_default": 2, + # Note that invite is not here. + "kick": 4, + "redact": 5, + "state_default": 6, + "users": {"@admin:domain": 100}, + "users_default": 7, + }, + "signatures": {}, + "unsigned": {}, + }, + ) + + # After MSC2176, power levels events keep the invite key. 
+ self.run_test( + {"type": "m.room.power_levels", "content": {"invite": 75}}, + { + "type": "m.room.power_levels", + "content": {"invite": 75}, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.MSC2176, + ) + def test_alias_event(self): """Alias events have special behavior up through room version 6.""" self.run_test( @@ -146,8 +264,7 @@ class PruneEventTestCase(unittest.TestCase): }, ) - def test_msc2432_alias_event(self): - """After MSC2432, alias events have no special behavior.""" + # After MSC2432, alias events have no special behavior. self.run_test( {"type": "m.room.aliases", "content": {"aliases": ["test"]}}, { @@ -159,6 +276,32 @@ class PruneEventTestCase(unittest.TestCase): room_version=RoomVersions.V6, ) + def test_redacts(self): + """Redaction events have no special behaviour until MSC2174/MSC2176.""" + + self.run_test( + {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}}, + { + "type": "m.room.redaction", + "content": {}, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.V6, + ) + + # After MSC2174, redaction events keep the redacts content key. + self.run_test( + {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}}, + { + "type": "m.room.redaction", + "content": {"redacts": "$test2:domain"}, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.MSC2176, + ) + class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): -- cgit 1.5.1 From 1b4d5d6acf8cfbe65601b881360ea730f9693d80 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 6 Jan 2021 12:33:20 -0500 Subject: Empty iterables should count towards cache usage. (#9028) --- changelog.d/9028.bugfix | 1 + synapse/util/caches/deferred_cache.py | 2 +- tests/util/caches/test_deferred_cache.py | 73 ++++++++++++++++++++++---------- 3 files changed, 52 insertions(+), 24 deletions(-) create mode 100644 changelog.d/9028.bugfix (limited to 'tests') diff --git a/changelog.d/9028.bugfix b/changelog.d/9028.bugfix new file mode 100644 index 0000000000..66666886a4 --- /dev/null +++ b/changelog.d/9028.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where some caches could grow larger than configured. 
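A note on the one-line fix that follows: with `iterable=True`, the cache charges each entry by its number of elements, so an entry holding an empty iterable used to cost zero and could accumulate indefinitely without ever counting towards the configured maximum. Flooring the size callback at one (`len(d) or 1`) makes every entry cost at least one unit. A minimal self-contained sketch of the failure mode (a toy cache for illustration only, not Synapse's actual LruCache):

from collections import OrderedDict

class ToyCache:
    # Evicts the oldest entries once the summed entry sizes exceed max_size.
    def __init__(self, max_size, size_callback):
        self._data = OrderedDict()
        self._max_size = max_size
        self._size_callback = size_callback
        self._current_size = 0

    def set(self, key, value):
        if key in self._data:
            self._current_size -= self._size_callback(self._data.pop(key))
        self._data[key] = value
        self._current_size += self._size_callback(value)
        while self._current_size > self._max_size:
            _, evicted = self._data.popitem(last=False)
            self._current_size -= self._size_callback(evicted)

buggy = ToyCache(max_size=3, size_callback=lambda d: len(d))
fixed = ToyCache(max_size=3, size_callback=lambda d: len(d) or 1)
for i in range(1000):
    buggy.set(i, [])  # costs 0, so nothing is ever evicted
    fixed.set(i, [])  # costs at least 1, so the cache stays bounded
print(len(buggy._data), len(fixed._data))  # 1000 3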
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 601305487c..1adc92eb90 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -105,7 +105,7 @@ class DeferredCache(Generic[KT, VT]): keylen=keylen, cache_name=name, cache_type=cache_type, - size_callback=(lambda d: len(d)) if iterable else None, + size_callback=(lambda d: len(d) or 1) if iterable else None, metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, ) # type: LruCache[KT, VT] diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index dadfabd46d..ecd9efc4df 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -25,13 +25,8 @@ from tests.unittest import TestCase class DeferredCacheTestCase(TestCase): def test_empty(self): cache = DeferredCache("test") - failed = False - try: + with self.assertRaises(KeyError): cache.get("foo") - except KeyError: - failed = True - - self.assertTrue(failed) def test_hit(self): cache = DeferredCache("test") @@ -155,13 +150,8 @@ class DeferredCacheTestCase(TestCase): cache.prefill(("foo",), 123) cache.invalidate(("foo",)) - failed = False - try: + with self.assertRaises(KeyError): cache.get(("foo",)) - except KeyError: - failed = True - - self.assertTrue(failed) def test_invalidate_all(self): cache = DeferredCache("testcache") @@ -215,13 +205,8 @@ class DeferredCacheTestCase(TestCase): cache.prefill(2, "two") cache.prefill(3, "three") # 1 will be evicted - failed = False - try: + with self.assertRaises(KeyError): cache.get(1) - except KeyError: - failed = True - - self.assertTrue(failed) cache.get(2) cache.get(3) @@ -239,13 +224,55 @@ class DeferredCacheTestCase(TestCase): cache.prefill(3, "three") - failed = False - try: + with self.assertRaises(KeyError): cache.get(2) - except KeyError: - failed = True - self.assertTrue(failed) + cache.get(1) + cache.get(3) + + def test_eviction_iterable(self): + cache = DeferredCache( + "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True, + ) + + cache.prefill(1, ["one", "two"]) + cache.prefill(2, ["three"]) + # Now access 1 again, thus causing 2 to be least-recently used + cache.get(1) + + # Now add an item to the cache, which evicts 2. + cache.prefill(3, ["four"]) + with self.assertRaises(KeyError): + cache.get(2) + + # Ensure 1 & 3 are in the cache. cache.get(1) cache.get(3) + + # Now access 1 again, thus causing 3 to be least-recently used + cache.get(1) + + # Now add an item with multiple elements to the cache + cache.prefill(4, ["five", "six"]) + + # Both 1 and 3 are evicted since there's too many elements. + with self.assertRaises(KeyError): + cache.get(1) + with self.assertRaises(KeyError): + cache.get(3) + + # Now add another item to fill the cache again. + cache.prefill(5, ["seven"]) + + # Now access 4, thus causing 5 to be least-recently used + cache.get(4) + + # Add an empty item. + cache.prefill(6, []) + + # 5 gets evicted and replaced since an empty element counts as an item. + with self.assertRaises(KeyError): + cache.get(5) + cache.get(4) + cache.get(6) -- cgit 1.5.1 From 8d3d264052adffaf9ef36a6d9235aeeb8bef5fb5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 7 Jan 2021 11:41:28 +0000 Subject: Skip unit tests which require optional dependencies (#9031) If we are lacking an optional dependency, skip the tests that rely on it. 
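The mechanics are the same in every test module this commit touches: guard the import, record a HAS_* flag, then either set a class-level `skip` attribute (trial skips the whole class) or use the new `skip_unless` decorator for individual tests. A condensed sketch, with `some_optional_lib` standing in for the real `authlib`, `jwt` and `lxml` guards:

try:
    import some_optional_lib  # noqa: F401

    HAS_LIB = True
except ImportError:
    HAS_LIB = False

from tests.unittest import TestCase, skip_unless

class FeatureTestCase(TestCase):
    # Skips every test in the class when the dependency is missing.
    if not HAS_LIB:
        skip = "requires some_optional_lib"

class MixedTestCase(TestCase):
    # Skips just this test, via the decorator added in tests/unittest.py below.
    @skip_unless(HAS_LIB, "requires some_optional_lib")
    def test_needs_lib(self):
        ...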
--- changelog.d/9031.misc | 1 + tests/handlers/test_oidc.py | 19 ++++++++++++++++++- tests/rest/client/v1/test_login.py | 11 ++++++++++- tests/rest/client/v2_alpha/test_auth.py | 26 ++++++++++++++++---------- tests/rest/media/v1/test_url_preview.py | 7 +++++++ tests/test_preview.py | 11 +++++++++++ tests/unittest.py | 28 +++++++++++++++++++++++++++- 7 files changed, 90 insertions(+), 13 deletions(-) create mode 100644 changelog.d/9031.misc (limited to 'tests') diff --git a/changelog.d/9031.misc b/changelog.d/9031.misc new file mode 100644 index 0000000000..f43611c385 --- /dev/null +++ b/changelog.d/9031.misc @@ -0,0 +1 @@ +Fix running unit tests when optional dependencies are not installed. diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 368d600b33..f5df657814 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -24,7 +24,6 @@ import pymacaroons from twisted.web.resource import Resource from synapse.api.errors import RedirectException -from synapse.handlers.oidc_handler import OidcError from synapse.handlers.sso import MappingException from synapse.rest.client.v1 import login from synapse.rest.synapse.client.pick_username import pick_username_resource @@ -34,6 +33,14 @@ from synapse.types import UserID from tests.test_utils import FakeResponse, simple_async_mock from tests.unittest import HomeserverTestCase, override_config +try: + import authlib  # noqa: F401 + + HAS_OIDC = True +except ImportError: + HAS_OIDC = False + + # These are a few constants that are used as config parameters in the tests. ISSUER = "https://issuer/" CLIENT_ID = "test-client-id" @@ -113,6 +120,9 @@ async def get_json(url): class OidcHandlerTestCase(HomeserverTestCase): + if not HAS_OIDC: + skip = "requires OIDC" + def default_config(self): config = super().default_config() config["public_baseurl"] = BASE_URL @@ -458,6 +468,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("fetch_error") # Handle code exchange failure + from synapse.handlers.oidc_handler import OidcError + self.handler._exchange_code = simple_async_mock( raises=OidcError("invalid_request") ) @@ -538,6 +550,8 @@ class OidcHandlerTestCase(HomeserverTestCase): body=b'{"error": "foo", "error_description": "bar"}', ) ) + from synapse.handlers.oidc_handler import OidcError + exc = self.get_failure(self.handler._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "foo") self.assertEqual(exc.value.error_description, "bar") @@ -829,6 +843,9 @@ class OidcHandlerTestCase(HomeserverTestCase): class UsernamePickerTestCase(HomeserverTestCase): + if not HAS_OIDC: + skip = "requires OIDC" + servlets = [login.register_servlets] def default_config(self): diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 999d628315..901c72d36a 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -4,7 +4,10 @@ import urllib.parse from mock import Mock -import jwt +try: + import jwt +except ImportError: + jwt = None import synapse.rest.admin from synapse.appservice import ApplicationService @@ -460,6 +463,9 @@ class CASTestCase(unittest.HomeserverTestCase): class JWTTestCase(unittest.HomeserverTestCase): + if not jwt: + skip = "requires jwt" + servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, ] @@ -628,6 +634,9 @@ class JWTTestCase(unittest.HomeserverTestCase): # RS256, with a public key configured in synapse as "jwt_secret", and tokens # signed by the private key.
class JWTPubKeyTestCase(unittest.HomeserverTestCase): + if not jwt: + skip = "requires jwt" + servlets = [ login.register_servlets, ] diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index ac66a4e0b7..bb91e0c331 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -26,8 +26,10 @@ from synapse.rest.oidc import OIDCResource from synapse.types import JsonDict, UserID from tests import unittest +from tests.handlers.test_oidc import HAS_OIDC from tests.rest.client.v1.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel +from tests.unittest import override_config, skip_unless class DummyRecaptchaChecker(UserInteractiveAuthChecker): @@ -158,20 +160,22 @@ class UIAuthTests(unittest.HomeserverTestCase): def default_config(self): config = super().default_config() + config["public_baseurl"] = "https://synapse.test" - # we enable OIDC as a way of testing SSO flows - oidc_config = {} - oidc_config.update(TEST_OIDC_CONFIG) - oidc_config["allow_existing_users"] = True + if HAS_OIDC: + # we enable OIDC as a way of testing SSO flows + oidc_config = {} + oidc_config.update(TEST_OIDC_CONFIG) + oidc_config["allow_existing_users"] = True + config["oidc_config"] = oidc_config - config["oidc_config"] = oidc_config - config["public_baseurl"] = "https://synapse.test" return config def create_resource_dict(self): resource_dict = super().create_resource_dict() - # mount the OIDC resource at /_synapse/oidc - resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) + if HAS_OIDC: + # mount the OIDC resource at /_synapse/oidc + resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) return resource_dict def prepare(self, reactor, clock, hs): @@ -380,6 +384,8 @@ class UIAuthTests(unittest.HomeserverTestCase): # Note that *no auth* information is provided, not even a session iD! self.delete_device(self.user_tok, self.device_id, 200) + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self): login_resp = self.helper.login_via_oidc("username") user_tok = login_resp["access_token"] @@ -393,13 +399,13 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) def test_does_not_offer_sso_for_password_user(self): - # now call the device deletion API: we should get the option to auth with SSO - # and not password. 
channel = self.delete_device(self.user_tok, self.device_id, 401) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self): """A user that had a password and then logged in with SSO should get both flows """ diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 83d728b4a4..6968502433 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -26,8 +26,15 @@ from twisted.test.proto_helpers import AccumulatingProtocol from tests import unittest from tests.server import FakeTransport +try: + import lxml +except ImportError: + lxml = None + class URLPreviewTests(unittest.HomeserverTestCase): + if not lxml: + skip = "url preview feature requires lxml" hijack_auth = True user_id = "@test:user" diff --git a/tests/test_preview.py b/tests/test_preview.py index a883d707df..c19facc1cb 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -20,8 +20,16 @@ from synapse.rest.media.v1.preview_url_resource import ( from . import unittest +try: + import lxml +except ImportError: + lxml = None + class PreviewTestCase(unittest.TestCase): + if not lxml: + skip = "url preview feature requires lxml" + def test_long_summarize(self): example_paras = [ """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: @@ -137,6 +145,9 @@ class PreviewTestCase(unittest.TestCase): class PreviewUrlTestCase(unittest.TestCase): + if not lxml: + skip = "url preview feature requires lxml" + def test_simple(self): html = """ diff --git a/tests/unittest.py b/tests/unittest.py index af7f752c5a..bbd295687c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union +from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from mock import Mock, patch @@ -736,3 +736,29 @@ def override_config(extra_config): return func return decorator + + +TV = TypeVar("TV") + + +def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]: + """A test decorator which will skip the decorated test unless a condition is set + + For example: + + class MyTestCase(TestCase): + @skip_unless(HAS_FOO, "Cannot test without foo") + def test_foo(self): + ... 
+ + Args: + condition: If False, the test will be skipped + reason: the reason to give for skipping the test + """ + + def decorator(f: TV) -> TV: + if not condition: + f.skip = reason  # type: ignore + return f + + return decorator -- cgit 1.5.1 From 3fc2399dbe0fa141e3bac4b45686eceb9f8d086e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 7 Jan 2021 12:16:23 +0000 Subject: black-format tests/rest/client/v1/test_login.py black seems to want to reformat this, despite `black --check` being happy with it :/ --- tests/rest/client/v1/test_login.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 901c72d36a..3b82511d97 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -413,8 +413,7 @@ class CASTestCase(unittest.HomeserverTestCase): } ) def test_cas_redirect_whitelisted(self): - """Tests that the SSO login flow serves a redirect to a whitelisted url - """ + """Tests that the SSO login flow serves a redirect to a whitelisted url""" self._test_redirect("https://legit-site.com/") @override_config({"public_baseurl": "https://example.com"}) @@ -773,8 +772,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): return self.hs def test_login_appservice_user(self): - """Test that an appservice user can use /login - """ + """Test that an appservice user can use /login""" self.register_as_user(AS_USER) params = { @@ -788,8 +786,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) def test_login_appservice_user_bot(self): - """Test that the appservice bot can use /login - """ + """Test that the appservice bot can use /login""" self.register_as_user(AS_USER) params = { @@ -803,8 +800,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) def test_login_appservice_wrong_user(self): - """Test that non-as users cannot login with the as token - """ + """Test that non-as users cannot login with the as token""" self.register_as_user(AS_USER) params = { @@ -818,8 +814,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) def test_login_appservice_wrong_as(self): - """Test that as users cannot login with wrong as token - """ + """Test that as users cannot login with wrong as token""" self.register_as_user(AS_USER) params = { @@ -834,7 +829,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): def test_login_appservice_no_token(self): """Test that users must provide a token when using the appservice - login method + login method """ self.register_as_user(AS_USER) -- cgit 1.5.1 From 23d701864fedbc40863b34c9c46c134295dd0a35 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 7 Jan 2021 08:03:38 -0500 Subject: Improve the performance of calculating ignored users in large rooms (#9024) This allows for efficiently finding which users ignore a particular user.
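In outline, the change swaps a per-member lookup for a single reverse query: previously the bulk push rule evaluator awaited `store.is_ignored_by(sender, uid)` once for every user with push rules in the room, each call consulting that user's full `m.ignored_user_list` account data; now it fetches the set of ignorers once and tests membership. A schematic comparison (the function names here are illustrative, the real logic lives in `BulkPushRuleEvaluator`, as the diff below shows):

# Before: one awaited lookup per local member with push rules.
async def evaluate_before(store, event, rules_by_user):
    for uid in rules_by_user:
        if not event.is_state() and await store.is_ignored_by(event.sender, uid):
            continue
        ...  # evaluate push rules for uid

# After: one reverse lookup, then cheap set-membership tests.
async def evaluate_after(store, event, rules_by_user):
    ignorers = await store.ignored_by(event.sender) if not event.is_state() else set()
    for uid in rules_by_user:
        if uid in ignorers:
            continue
        ...  # evaluate push rules for uid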
Co-authored-by: Erik Johnston --- changelog.d/9024.feature | 1 + synapse/push/bulk_push_rule_evaluator.py | 12 +- synapse/storage/databases/main/account_data.py | 121 ++++++++++++++++----- .../main/schema/delta/59/01ignored_user.py | 82 ++++++++++++++ synapse/storage/prepare_database.py | 2 +- tests/storage/test_account_data.py | 120 ++++++++++++++++++++ 6 files changed, 304 insertions(+), 34 deletions(-) create mode 100644 changelog.d/9024.feature create mode 100644 synapse/storage/databases/main/schema/delta/59/01ignored_user.py create mode 100644 tests/storage/test_account_data.py (limited to 'tests') diff --git a/changelog.d/9024.feature b/changelog.d/9024.feature new file mode 100644 index 0000000000..073dafbf83 --- /dev/null +++ b/changelog.d/9024.feature @@ -0,0 +1 @@ +Improved performance when calculating ignored users in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 10f27e4378..9018f9e20b 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -203,14 +203,18 @@ class BulkPushRuleEvaluator: condition_cache = {} # type: Dict[str, bool] + # If the event is not a state event check if any users ignore the sender. + if not event.is_state(): + ignorers = await self.store.ignored_by(event.sender) + else: + ignorers = set() + for uid, rules in rules_by_user.items(): if event.sender == uid: continue - if not event.is_state(): - is_ignored = await self.store.is_ignored_by(event.sender, uid) - if is_ignored: - continue + if uid in ignorers: + continue display_name = None profile_info = room_members.get(uid) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 49ee23470d..bff51e92b9 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -16,7 +16,7 @@ import abc import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple from synapse.api.constants import AccountDataTypes from synapse.storage._base import SQLBaseStore, db_to_json @@ -24,7 +24,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.caches.descriptors import _CacheContext, cached +from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -287,23 +287,25 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) - @cached(num_args=2, cache_context=True, max_entries=5000) - async def is_ignored_by( - self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext - ) -> bool: - ignored_account_data = await self.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, - ignorer_user_id, - on_invalidate=cache_context.invalidate, - ) - if not ignored_account_data: - return False + @cached(max_entries=5000, iterable=True) + async def ignored_by(self, user_id: str) -> Set[str]: + """ + Get users which ignore the given user. - try: - return ignored_user_id in ignored_account_data.get("ignored_users", {}) - except TypeError: - # The type of the ignored_users field is invalid. - return False + Params: + user_id: The user ID which might be ignored. 
+ + Return: + The user IDs which ignore the given user. + """ + return set( + await self.db_pool.simple_select_onecol( + table="ignored_users", + keyvalues={"ignored_user_id": user_id}, + retcol="ignorer_user_id", + desc="ignored_by", + ) + ) class AccountDataStore(AccountDataWorkerStore): @@ -390,18 +392,14 @@ class AccountDataStore(AccountDataWorkerStore): Returns: The maximum stream ID. """ - content_json = json_encoder.encode(content) - async with self._account_data_id_gen.get_next() as next_id: - # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so simple_upsert will retry if - # there is a conflict. - await self.db_pool.simple_upsert( - desc="add_user_account_data", - table="account_data", - keyvalues={"user_id": user_id, "account_data_type": account_data_type}, - values={"stream_id": next_id, "content": content_json}, - lock=False, + await self.db_pool.runInteraction( + "add_user_account_data", + self._add_account_data_for_user, + next_id, + user_id, + account_data_type, + content, ) # it's theoretically possible for the above to succeed and the @@ -424,6 +422,71 @@ class AccountDataStore(AccountDataWorkerStore): return self._account_data_id_gen.get_current_token() + def _add_account_data_for_user( + self, + txn, + next_id: int, + user_id: str, + account_data_type: str, + content: JsonDict, + ) -> None: + content_json = json_encoder.encode(content) + + # no need to lock here as account_data has a unique constraint on + # (user_id, account_data_type) so simple_upsert will retry if + # there is a conflict. + self.db_pool.simple_upsert_txn( + txn, + table="account_data", + keyvalues={"user_id": user_id, "account_data_type": account_data_type}, + values={"stream_id": next_id, "content": content_json}, + lock=False, + ) + + # Ignored users get denormalized into a separate table as an optimisation. + if account_data_type != AccountDataTypes.IGNORED_USER_LIST: + return + + # Insert / delete to sync the list of ignored users. + previously_ignored_users = set( + self.db_pool.simple_select_onecol_txn( + txn, + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + ) + ) + + # If the data is invalid, no one is ignored. + ignored_users_content = content.get("ignored_users", {}) + if isinstance(ignored_users_content, dict): + currently_ignored_users = set(ignored_users_content) + else: + currently_ignored_users = set() + + # Delete entries which are no longer ignored. + self.db_pool.simple_delete_many_txn( + txn, + table="ignored_users", + column="ignored_user_id", + iterable=previously_ignored_users - currently_ignored_users, + keyvalues={"ignorer_user_id": user_id}, + ) + + # Add entries which are newly ignored. + self.db_pool.simple_insert_many_txn( + txn, + table="ignored_users", + values=[ + {"ignorer_user_id": user_id, "ignored_user_id": u} + for u in currently_ignored_users - previously_ignored_users + ], + ) + + # Invalidate the cache for any ignored users which were added or removed. 
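+        # (`^` is the set symmetric difference: exactly those users that were
+        # either newly added to, or newly removed from, the ignore list -- the
+        # only `ignored_by` cache entries that can have changed.)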
+        for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
+            self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+
     async def _update_max_stream_id(self, next_id: int) -> None:
         """Update the max stream_id

diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
new file mode 100644
index 0000000000..f35c70b699
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -0,0 +1,82 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration denormalises the account_data table into an ignored users table.
+"""
+
+import logging
+from io import StringIO
+
+from synapse.storage._base import db_to_json
+from synapse.storage.engines import BaseDatabaseEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
+from synapse.storage.types import Cursor
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    pass
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    logger.info("Creating ignored_users table")
+    execute_statements_from_stream(cur, StringIO(_create_commands))
+
+    # We now upgrade existing data, if any. We don't do this in `run_upgrade` as
+    # we a) want to run these before adding constraints and b) `run_upgrade` is
+    # not run on empty databases.
+    insert_sql = """
+    INSERT INTO ignored_users (ignorer_user_id, ignored_user_id) VALUES (?, ?)
+    """
+
+    logger.info("Converting existing ignore lists")
+    cur.execute(
+        "SELECT user_id, content FROM account_data WHERE account_data_type = 'm.ignored_user_list'"
+    )
+    for user_id, content_json in cur.fetchall():
+        content = db_to_json(content_json)
+
+        # The content should be in the form of a dictionary with a key
+        # "ignored_users" pointing to a dictionary with keys of ignored users.
+        #
+        # { "ignored_users": { "@someone:example.org": {} } }
+        ignored_users = content.get("ignored_users", {})
+        if isinstance(ignored_users, dict) and ignored_users:
+            cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+
+    # Add indexes after inserting data for efficiency.
+    logger.info("Adding constraints to ignored_users table")
+    execute_statements_from_stream(cur, StringIO(_constraints_commands))
+
+
+# there might be duplicates, so the easiest way to achieve this is to create a new
+# table with the right data, and rename it into place
+
+_create_commands = """
+-- Users which are ignored when calculating push notifications. This data is
+-- denormalized from account data.
+CREATE TABLE IF NOT EXISTS ignored_users(
+    ignorer_user_id TEXT NOT NULL,  -- The user ID of the user who is ignoring another user. (This is a local user.)
+    ignored_user_id TEXT NOT NULL  -- The user ID of the user who is being ignored. (This is a local or remote user.)
+);
+"""
+
+_constraints_commands = """
+CREATE UNIQUE INDEX ignored_users_uniqueness ON ignored_users (ignorer_user_id, ignored_user_id);
+
+-- Add an index on ignored_users since look-ups are done to get all ignorers of an ignored user.
+CREATE INDEX ignored_users_ignored_user_id ON ignored_users (ignored_user_id);
+"""
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6684403a0a..01efb2cabb 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -38,7 +38,7 @@ logger = logging.getLogger(__name__)
 # XXX: If you're about to bump this to 59 (or higher) please create an update
 # that drops the unused `cache_invalidation_stream` table, as per #7436!
 # XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
-SCHEMA_VERSION = 58
+SCHEMA_VERSION = 59

 dir_path = os.path.abspath(os.path.dirname(__file__))

diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
new file mode 100644
index 0000000000..673e1fe3e3
--- /dev/null
+++ b/tests/storage/test_account_data.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Optional, Set
+
+from synapse.api.constants import AccountDataTypes
+
+from tests import unittest
+
+
+class IgnoredUsersTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.store = self.hs.get_datastore()
+        self.user = "@user:test"
+
+    def _update_ignore_list(
+        self, *ignored_user_ids: Iterable[str], ignorer_user_id: Optional[str] = None
+    ) -> None:
+        """Update the account data to ignore the given users."""
+        if ignorer_user_id is None:
+            ignorer_user_id = self.user
+
+        self.get_success(
+            self.store.add_account_data_for_user(
+                ignorer_user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {u: {} for u in ignored_user_ids}},
+            )
+        )
+
+    def assert_ignorers(
+        self, ignored_user_id: str, expected_ignorer_user_ids: Set[str]
+    ) -> None:
+        self.assertEqual(
+            self.get_success(self.store.ignored_by(ignored_user_id)),
+            expected_ignorer_user_ids,
+        )
+
+    def test_ignoring_users(self):
+        """Basic adding/removing of users from the ignore list."""
+        self._update_ignore_list("@other:test", "@another:remote")
+
+        # Check a user which no one ignores.
+        self.assert_ignorers("@user:test", set())
+
+        # Check a local user which is ignored.
+        self.assert_ignorers("@other:test", {self.user})
+
+        # Check a remote user which is ignored.
+        self.assert_ignorers("@another:remote", {self.user})
+
+        # Add one user, remove one user, and leave one user.
+        self._update_ignore_list("@foo:test", "@another:remote")
+
+        # Check the removed user.
+        self.assert_ignorers("@other:test", set())
+
+        # Check the added user.
+        self.assert_ignorers("@foo:test", {self.user})
+
+        # Check the user which was kept on the list.
+ self.assert_ignorers("@another:remote", {self.user}) + + def test_caching(self): + """Ensure that caching works properly between different users.""" + # The first user ignores a user. + self._update_ignore_list("@other:test") + self.assert_ignorers("@other:test", {self.user}) + + # The second user ignores them. + self._update_ignore_list("@other:test", ignorer_user_id="@second:test") + self.assert_ignorers("@other:test", {self.user, "@second:test"}) + + # The first user un-ignores them. + self._update_ignore_list() + self.assert_ignorers("@other:test", {"@second:test"}) + + def test_invalid_data(self): + """Invalid data ends up clearing out the ignored users list.""" + # Add some data and ensure it is there. + self._update_ignore_list("@other:test") + self.assert_ignorers("@other:test", {self.user}) + + # No ignored_users key. + self.get_success( + self.store.add_account_data_for_user( + self.user, AccountDataTypes.IGNORED_USER_LIST, {}, + ) + ) + + # No one ignores the user now. + self.assert_ignorers("@other:test", set()) + + # Add some data and ensure it is there. + self._update_ignore_list("@other:test") + self.assert_ignorers("@other:test", {self.user}) + + # Invalid data. + self.get_success( + self.store.add_account_data_for_user( + self.user, + AccountDataTypes.IGNORED_USER_LIST, + {"ignored_users": "unexpected"}, + ) + ) + + # No one ignores the user now. + self.assert_ignorers("@other:test", set()) -- cgit 1.5.1 From bbd04441edc414c161748e1499961c6a68a460aa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 Jan 2021 15:51:18 +0000 Subject: Fix type hints in test_login.py --- mypy.ini | 1 + tests/rest/client/v1/test_login.py | 78 ++++++++++++++++++++++++++------------ 2 files changed, 55 insertions(+), 24 deletions(-) (limited to 'tests') diff --git a/mypy.ini b/mypy.ini index 5d15b7bf1c..b996867121 100644 --- a/mypy.ini +++ b/mypy.ini @@ -103,6 +103,7 @@ files = tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, + tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_stream_change_cache.py diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 3b82511d97..b5a2ac23d4 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,14 +1,24 @@ -import json +# -*- coding: utf-8 -*- +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
 import time
 import urllib.parse
+from typing import Any, Dict, Union

 from mock import Mock

-try:
-    import jwt
-except ImportError:
-    jwt = None
-
 import synapse.rest.admin
 from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
@@ -16,7 +26,33 @@ from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet

 from tests import unittest
-from tests.unittest import override_config
+from tests.unittest import override_config, skip_unless
+
+try:
+    import jwt
+
+    HAS_JWT = True
+except ImportError:
+    HAS_JWT = False
+
+
+# public_base_url used in some tests
+BASE_URL = "https://synapse/"
+
+# CAS server used in some tests
+CAS_SERVER = "https://fake.test"
+
+# just enough to tell pysaml2 where to redirect to
+SAML_SERVER = "https://test.saml.server/idp/sso"
+TEST_SAML_METADATA = """
+<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
+  <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
+    <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
+  </md:IDPSSODescriptor>
+</md:EntityDescriptor>
+""" % {
+    "SAML_SERVER": SAML_SERVER,
+}

 LOGIN_URL = b"/_matrix/client/r0/login"
 TEST_URL = b"/_matrix/client/r0/account/whoami"
@@ -461,10 +497,8 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.assertIn(b"SSO account deactivated", channel.result["body"])


+@skip_unless(HAS_JWT, "requires jwt")
 class JWTTestCase(unittest.HomeserverTestCase):
-    if not jwt:
-        skip = "requires jwt"
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -480,17 +514,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs

-    def jwt_encode(self, token: str, secret: str = jwt_secret) -> str:
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(token, secret, self.jwt_algorithm)
+        result = jwt.encode(
+            payload, secret, self.jwt_algorithm
+        )  # type: Union[str, bytes]
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result

     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel

@@ -622,7 +656,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         )

     def test_login_no_token(self):
-        params = json.dumps({"type": "org.matrix.login.jwt"})
+        params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -632,10 +666,8 @@ class JWTTestCase(unittest.HomeserverTestCase):

 # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
 # RS256, with a public key configured in synapse as "jwt_secret", and tokens
 # signed by the private key.
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTPubKeyTestCase(unittest.HomeserverTestCase):
-    if not jwt:
-        skip = "requires jwt"
-
     servlets = [
         login.register_servlets,
     ]
@@ -692,17 +724,15 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = "RS256"
         return self.hs

-    def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str:
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result = jwt.encode(token, secret, "RS256") + result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str] if isinstance(result, bytes): return result.decode("ascii") return result def jwt_login(self, *args): - params = json.dumps( - {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - ) + params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} channel = self.make_request(b"POST", LOGIN_URL, params) return channel -- cgit 1.5.1 From 8a910f97a47edb398acba2ad87805e4593c76139 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 Jan 2021 16:16:16 +0000 Subject: Add some tests for the IDP picker flow --- synapse/rest/client/v1/login.py | 4 +- tests/rest/client/v1/test_login.py | 191 ++++++++++++++++++++++++++++++++++++- tests/rest/client/v1/utils.py | 3 +- 3 files changed, 193 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index ebc346105b..be938df962 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -319,9 +319,9 @@ class SsoRedirectServlet(RestServlet): # register themselves with the main SSOHandler. if hs.config.cas_enabled: hs.get_cas_handler() - elif hs.config.saml2_enabled: + if hs.config.saml2_enabled: hs.get_saml_handler() - elif hs.config.oidc_enabled: + if hs.config.oidc_enabled: hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index b5a2ac23d4..1d1dc9f8a2 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,17 +15,26 @@ import time import urllib.parse -from typing import Any, Dict, Union +from html.parser import HTMLParser +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from mock import Mock +import pymacaroons + +from twisted.web.resource import Resource + import synapse.rest.admin from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet +from synapse.rest.synapse.client.pick_idp import PickIdpResource from tests import unittest +from tests.handlers.test_oidc import HAS_OIDC +from tests.handlers.test_saml import has_saml2 +from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG from tests.unittest import override_config, skip_unless try: @@ -350,6 +359,184 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) +@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") +class MultiSSOTestCase(unittest.HomeserverTestCase): + """Tests for homeservers with multiple SSO providers enabled""" + + servlets = [ + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + config["public_baseurl"] = BASE_URL + + config["cas_config"] = { + "enabled": True, + "server_url": CAS_SERVER, + "service_url": "https://matrix.goodserver.com:8448", + } + + config["saml2_config"] = { + "sp_config": { + "metadata": {"inline": [TEST_SAML_METADATA]}, + # use the XMLSecurity backend to avoid relying on xmlsec1 + "crypto_backend": "XMLSecurity", + }, + } + + config["oidc_config"] = TEST_OIDC_CONFIG + + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/client/pick_idp"] = 
PickIdpResource(self.hs) + return d + + def test_multi_sso_redirect(self): + """/login/sso/redirect should redirect to an identity picker""" + client_redirect_url = "https://x?" + + # first hit the redirect url, which should redirect to our idp picker + channel = self.make_request( + "GET", + "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url, + ) + self.assertEqual(channel.code, 302, channel.result) + uri = channel.headers.getRawHeaders("Location")[0] + + # hitting that picker should give us some HTML + channel = self.make_request("GET", uri) + self.assertEqual(channel.code, 200, channel.result) + + # parse the form to check it has fields assumed elsewhere in this class + class FormPageParser(HTMLParser): + def __init__(self): + super().__init__() + + # the values of the hidden inputs: map from name to value + self.hiddens = {} # type: Dict[str, Optional[str]] + + # the values of the radio buttons + self.radios = [] # type: List[Optional[str]] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "input": + if attr_dict["type"] == "radio" and attr_dict["name"] == "idp": + self.radios.append(attr_dict["value"]) + elif attr_dict["type"] == "hidden": + input_name = attr_dict["name"] + assert input_name + self.hiddens[input_name] = attr_dict["value"] + + def error(_, message): + self.fail(message) + + p = FormPageParser() + p.feed(channel.result["body"].decode("utf-8")) + p.close() + + self.assertCountEqual(p.radios, ["cas", "oidc", "saml"]) + + self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url) + + def test_multi_sso_redirect_to_cas(self): + """If CAS is chosen, should redirect to the CAS server""" + client_redirect_url = "https://x?" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", + shorthand=False, + ) + self.assertEqual(channel.code, 302, channel.result) + cas_uri = channel.headers.getRawHeaders("Location")[0] + cas_uri_path, cas_uri_query = cas_uri.split("?", 1) + + # it should redirect us to the login page of the cas server + self.assertEqual(cas_uri_path, CAS_SERVER + "/login") + + # check that the redirectUrl is correctly encoded in the service param - ie, the + # place that CAS will redirect to + cas_uri_params = urllib.parse.parse_qs(cas_uri_query) + service_uri = cas_uri_params["service"][0] + _, service_uri_query = service_uri.split("?", 1) + service_uri_params = urllib.parse.parse_qs(service_uri_query) + self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) + + def test_multi_sso_redirect_to_saml(self): + """If SAML is chosen, should redirect to the SAML server""" + client_redirect_url = "https://x?" 
+ + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + client_redirect_url + + "&idp=saml", + ) + self.assertEqual(channel.code, 302, channel.result) + saml_uri = channel.headers.getRawHeaders("Location")[0] + saml_uri_path, saml_uri_query = saml_uri.split("?", 1) + + # it should redirect us to the login page of the SAML server + self.assertEqual(saml_uri_path, SAML_SERVER) + + # the RelayState is used to carry the client redirect url + saml_uri_params = urllib.parse.parse_qs(saml_uri_query) + relay_state_param = saml_uri_params["RelayState"][0] + self.assertEqual(relay_state_param, client_redirect_url) + + def test_multi_sso_redirect_to_oidc(self): + """If OIDC is chosen, should redirect to the OIDC auth endpoint""" + client_redirect_url = "https://x?" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + client_redirect_url + + "&idp=oidc", + ) + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + # ... and should have set a cookie including the redirect url + cookies = dict( + h.split(";")[0].split("=", maxsplit=1) + for h in channel.headers.getRawHeaders("Set-Cookie") + ) + + oidc_session_cookie = cookies["oidc_session"] + macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) + self.assertEqual( + self._get_value_from_macaroon(macaroon, "client_redirect_url"), + client_redirect_url, + ) + + def test_multi_sso_redirect_to_unknown(self): + """An unknown IdP should cause a 400""" + channel = self.make_request( + "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + ) + self.assertEqual(channel.code, 400, channel.result) + + @staticmethod + def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise ValueError("No %s caveat in macaroon" % (key,)) + + class CASTestCase(unittest.HomeserverTestCase): servlets = [ @@ -363,7 +550,7 @@ class CASTestCase(unittest.HomeserverTestCase): config = self.default_config() config["cas_config"] = { "enabled": True, - "server_url": "https://fake.test", + "server_url": CAS_SERVER, "service_url": "https://matrix.goodserver.com:8448", } diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index dbc27893b5..81b7f84360 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -444,6 +444,7 @@ class RestHelper: # an 'oidc_config' suitable for login_via_oidc. 
+TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
 TEST_OIDC_CONFIG = {
     "enabled": True,
     "discover": False,
@@ -451,7 +452,7 @@
     "client_id": "test-client-id",
     "client_secret": "test-client-secret",
     "scopes": ["profile"],
-    "authorization_endpoint": "https://z",
+    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
     "token_endpoint": "https://issuer.test/token",
     "userinfo_endpoint": "https://issuer.test/userinfo",
     "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
-- cgit 1.5.1 From d32870ffa5a2353d93e5723787d5f4dcbf14b32d Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Fri, 8 Jan 2021 14:23:04 +0000
Subject: Fix validate_config on nested objects (#9054)

---
 changelog.d/9054.bugfix   |  1 +
 synapse/config/_util.py   |  2 +-
 tests/config/test_util.py | 53 +++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 55 insertions(+), 1 deletion(-)
 create mode 100644 changelog.d/9054.bugfix
 create mode 100644 tests/config/test_util.py

(limited to 'tests')

diff --git a/changelog.d/9054.bugfix b/changelog.d/9054.bugfix
new file mode 100644
index 0000000000..0bfe951f17
--- /dev/null
+++ b/changelog.d/9054.bugfix
@@ -0,0 +1 @@
+Fix a minor bug which could cause confusing error messages from invalid configurations.
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index 1bbe83c317..8fce7f6bb1 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -56,7 +56,7 @@ def json_error_to_config_error(
     """
     # copy `config_path` before modifying it.
     path = list(config_path)
-    for p in list(e.path):
+    for p in list(e.absolute_path):
         if isinstance(p, int):
             path.append("<item %i>" % p)
         else:
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
new file mode 100644
index 0000000000..10363e3765
--- /dev/null
+++ b/tests/config/test_util.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config import ConfigError
+from synapse.config._util import validate_config
+
+from tests.unittest import TestCase
+
+
+class ValidateConfigTestCase(TestCase):
+    """Test cases for synapse.config._util.validate_config"""
+
+    def test_bad_object_in_array(self):
+        """malformed objects within an array should be validated correctly"""
+
+        # consider a structure:
+        #
+        # array_of_objs:
+        #   - r: 1
+        #     foo: 2
+        #
+        #   - r: 2
+        #     bar: 3
+        #
+        # ... where each entry must contain an "r": check that the path
+        # to the required item is correctly reported.
+
+        schema = {
+            "type": "object",
+            "properties": {
+                "array_of_objs": {
+                    "type": "array",
+                    "items": {"type": "object", "required": ["r"]},
+                },
+            },
+        }
+
+        with self.assertRaises(ConfigError) as c:
+            validate_config(schema, {"array_of_objs": [{}]}, ("base",))
+
+        self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])
-- cgit 1.5.1 From 1315a2e8be702a513d49c1142e9e52b642286635 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 11 Jan 2021 16:09:22 +0000
Subject: Use a chain cover index to efficiently calculate auth chain difference
 (#8868)

---
 changelog.d/8868.misc                              |   1 +
 docs/auth_chain_diff.dot                           |  32 ++
 docs/auth_chain_diff.dot.png                       | Bin 0 -> 42427 bytes
 docs/auth_chain_difference_algorithm.md            | 108 +++++
 synapse/storage/database.py                        |  22 +-
 synapse/storage/databases/main/event_federation.py | 185 +++++++
 synapse/storage/databases/main/events.py           | 535 ++++++++++++++++++++-
 synapse/storage/databases/main/room.py             |  51 +-
 .../main/schema/delta/59/04_event_auth_chains.sql  |  52 ++
 .../delta/59/04_event_auth_chains.sql.postgres     |  16 +
 synapse/util/iterutils.py                          |  53 +-
 tests/storage/test_event_chain.py                  | 472 ++++++++++++++++++
 tests/storage/test_event_federation.py             | 249 +++++++++-
 tests/util/test_itertools.py                       |  41 +-
 14 files changed, 1769 insertions(+), 48 deletions(-)
 create mode 100644 changelog.d/8868.misc
 create mode 100644 docs/auth_chain_diff.dot
 create mode 100644 docs/auth_chain_diff.dot.png
 create mode 100644 docs/auth_chain_difference_algorithm.md
 create mode 100644 synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
 create mode 100644 synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
 create mode 100644 tests/storage/test_event_chain.py

(limited to 'tests')

diff --git a/changelog.d/8868.misc b/changelog.d/8868.misc
new file mode 100644
index 0000000000..1a11e30944
--- /dev/null
+++ b/changelog.d/8868.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions for new rooms.
diff --git a/docs/auth_chain_diff.dot b/docs/auth_chain_diff.dot
new file mode 100644
index 0000000000..978d579ada
--- /dev/null
+++ b/docs/auth_chain_diff.dot
@@ -0,0 +1,32 @@
+digraph auth {
+    nodesep=0.5;
+    rankdir="RL";
+
+    C [label="Create (1,1)"];
+
+    BJ [label="Bob's Join (2,1)", color=red];
+    BJ2 [label="Bob's Join (2,2)", color=red];
+    BJ2 -> BJ [color=red, dir=none];
+
+    subgraph cluster_foo {
+        A1 [label="Alice's invite (4,1)", color=blue];
+        A2 [label="Alice's Join (4,2)", color=blue];
+        A3 [label="Alice's Join (4,3)", color=blue];
+        A3 -> A2 -> A1 [color=blue, dir=none];
+        color=none;
+    }
+
+    PL1 [label="Power Level (3,1)", color=darkgreen];
+    PL2 [label="Power Level (3,2)", color=darkgreen];
+    PL2 -> PL1 [color=darkgreen, dir=none];
+
+    {rank = same; C; BJ; PL1; A1;}
+
+    A1 -> C [color=grey];
+    A1 -> BJ [color=grey];
+    PL1 -> C [color=grey];
+    BJ2 -> PL1 [penwidth=2];
+
+    A3 -> PL2 [penwidth=2];
+    A1 -> PL1 -> BJ -> C [penwidth=2];
+}
diff --git a/docs/auth_chain_diff.dot.png b/docs/auth_chain_diff.dot.png
new file mode 100644
index 0000000000..771c07308f
Binary files /dev/null and b/docs/auth_chain_diff.dot.png differ
diff --git a/docs/auth_chain_difference_algorithm.md b/docs/auth_chain_difference_algorithm.md
new file mode 100644
index 0000000000..30f72a70da
--- /dev/null
+++ b/docs/auth_chain_difference_algorithm.md
@@ -0,0 +1,108 @@
+# Auth Chain Difference Algorithm
+
+The auth chain difference algorithm is used by V2 state resolution, where a
+naive implementation can be a significant source of CPU and DB usage.
+
+## Definitions
+
+A *state set* is a set of state events; e.g. the input of a state resolution
+algorithm is a collection of state sets.
+
+The *auth chain* of a set of events is all the events' auth events and *their*
+auth events, recursively (i.e. the events reachable by walking the graph induced
+by an event's auth event links).
+
+The *auth chain difference* of a collection of state sets is the union minus the
+intersection of the sets of auth chains corresponding to the state sets, i.e. an
+event is in the auth chain difference if it is reachable by walking the auth
+event graph from at least one of the state sets but not from *all* of the state
+sets.
+
+## Breadth First Walk Algorithm
+
+A way of calculating the auth chain difference without calculating the full auth
+chains for each state set is to do a parallel breadth first walk (ordered by
+depth) of each state set's auth chain. By tracking which events are reachable
+from each state set we can finish early if every pending event is reachable from
+every state set.
+
+This can work well for state sets that have a small auth chain difference, but
+can be very inefficient for larger differences. However, this algorithm is still
+used if we don't have a chain cover index for the room (e.g. because we're in
+the process of indexing it).
+
+## Chain Cover Index
+
+Synapse computes auth chain differences by pre-computing a "chain cover" index
+for the auth chain in a room, allowing efficient reachability queries like "is
+event A in the auth chain of event B". This is done by assigning every event a
+*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
+between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
+is in the auth chain of `B`) if and only if either:
+
+1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
+   sequence number; or
+2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
+   `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
+
+There are actually two potential implementations, one where we store links from
+each chain to every other reachable chain (the transitive closure of the links
+graph), and one where we remove redundant links (the transitive reduction of the
+links graph), e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
+would not be stored. Synapse uses the former implementation so that it doesn't
+need to recurse to test reachability between chains.
+
+### Example
+
+An example auth graph would look like the following, where chains have been
+formed based on type/state_key and are denoted by colour and are labelled with
+`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
+are those that would be removed in the second implementation described above).
+
+![Example](auth_chain_diff.dot.png)
+
+Note that we don't include all links between events and their auth events, as
+most of those links would be redundant. For example, all events point to the
+create event, but each chain only needs the one link from its base to the
+create event.
+
+## Using the Index
+
+This index can be used to calculate the auth chain difference of the state sets
+by looking at the chain ID and sequence numbers reachable from each state set:
+
+1. For every state set, look up the chain ID/sequence numbers of each state event.
+2. Use the index to find all chains and the maximum sequence number reachable
+   from each state set.
+3.
The auth chain difference is then all events in each chain that have sequence
+   numbers between the maximum sequence number reachable from *any* state set and
+   the minimum reachable by *all* state sets (if any).
+
+Note that step 2 is effectively calculating the auth chain for each state set
+(in terms of chain IDs and sequence numbers), and step 3 is calculating the
+difference between the union and intersection of the auth chains.
+
+### Worked Example
+
+For example, given the above graph, we can calculate the difference between
+state sets consisting of:
+
+1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
+2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
+
+Using the index we see that the following auth chains are reachable from each
+state set:
+
+1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
+2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
+
+And so, for each chain, the ranges that are in the auth chain difference:
+1. Chain 1: None (since everything can reach the create event).
+2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
+   sets and the maximum reachable is `2` (corresponding to Bob's second join).
+3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
+   level).
+4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
+
+So the final result is: Bob's second join `(2,2)`, the second power level
+`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b70ca3087b..6cfadc2b4e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -179,6 +179,9 @@ class LoggingDatabaseConnection:
 _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]


+R = TypeVar("R")
+
+
 class LoggingTransaction:
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -266,6 +269,20 @@ class LoggingTransaction:
             for val in args:
                 self.execute(sql, val)

+    def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+        """Corresponds to psycopg2.extras.execute_values. Only available when
+        using postgres.
+
+        Always sets fetch=True when calling `execute_values`, so will return the
+        results.
+        """
+        assert isinstance(self.database_engine, PostgresEngine)
+        from psycopg2.extras import execute_values  # type: ignore
+
+        return self._do_execute(
+            lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+        )
+
     def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)

@@ -276,7 +293,7 @@ class LoggingTransaction:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())

-    def _do_execute(self, func, sql: str, *args: Any) -> None:
+    def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
         sql = self._make_sql_one_line(sql)

         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -347,9 +364,6 @@ class PerformanceCounters:
         return top_n_counters


-R = TypeVar("R")
-
-
 class DatabasePool:
     """Wraps a single physical database and connection pool.
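As a concrete rendering of the reachability rule from
`docs/auth_chain_difference_algorithm.md` above, here is a minimal standalone
sketch (illustrative only, not code from this patch; the link encoding below
is an assumption made for the example, loosely mirroring the
`event_auth_chain_links` columns) that tests whether one event is in another's
auth chain, using the chains from the worked example:

    # Chain tuples are (chain_id, sequence_number). Links are kept in
    # transitive-closure form, keyed by (origin_chain, target_chain), each
    # holding a list of (origin_seq, target_seq) pairs.
    links = {
        (2, 1): [(1, 1)],          # Bob's join -> Create
        (3, 1): [(1, 1)],          # Power Level -> Create
        (3, 2): [(1, 1)],          # Power Level -> Bob's join
        (2, 3): [(2, 1)],          # Bob's second join -> Power Level
        (4, 1): [(1, 1)],          # Alice's invite -> Create
        (4, 2): [(1, 1)],          # Alice's invite -> Bob's join
        (4, 3): [(1, 1), (3, 2)],  # Alice's invite -> PL1; Alice's 2nd join -> PL2
    }

    def in_auth_chain(a, b):
        """Is event A in the auth chain of event B?"""
        (ca, sa), (cb, sb) = a, b
        if ca == cb:
            # Same chain: earlier sequence numbers are in the auth chain.
            return sa < sb
        # Otherwise we need a link from B's chain to A's chain that B can
        # use (origin_seq <= B's seq) and that reaches far enough (A's seq
        # <= target_seq).
        return any(o <= sb and sa <= t for o, t in links.get((cb, ca), []))

    assert in_auth_chain((1, 1), (4, 3))      # Create is in Alice's 2nd join's chain
    assert in_auth_chain((3, 2), (4, 3))      # ...as is the second power level
    assert not in_auth_chain((2, 2), (4, 1))  # Bob's 2nd join isn't in Alice's invite's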
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ebffd89251..8326640d20 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.types import Cursor from synapse.types import Collection from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) +class _NoChainCoverIndex(Exception): + def __init__(self, room_id: str): + super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) + + class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas The set of the difference in auth chains. """ + # Check if we have indexed the room so we can use the chain cover + # algorithm. + room = await self.get_room(room_id) + if room["has_auth_chain_index"]: + try: + return await self.db_pool.runInteraction( + "get_auth_chain_difference_chains", + self._get_auth_chain_difference_using_cover_index_txn, + room_id, + state_sets, + ) + except _NoChainCoverIndex: + # For whatever reason we don't actually have a chain cover index + # for the events in question, so we fall back to the old method. + pass + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, ) + def _get_auth_chain_difference_using_cover_index_txn( + self, txn: Cursor, room_id: str, state_sets: List[Set[str]] + ) -> Set[str]: + """Calculates the auth chain difference using the chain index. + + See docs/auth_chain_difference_algorithm.md for details + """ + + # First we look up the chain ID/sequence numbers for all the events, and + # work out the chain/sequence numbers reachable from each state set. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Map from event_id -> (chain ID, seq no) + chain_info = {} # type: Dict[str, Tuple[int, int]] + + # Map from chain ID -> seq no -> event Id + chain_to_event = {} # type: Dict[int, Dict[int, str]] + + # All the chains that we've found that are reachable from the state + # sets. + seen_chains = set() # type: Set[int] + + sql = """ + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE %s + """ + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + for event_id, chain_id, sequence_number in txn: + chain_info[event_id] = (chain_id, sequence_number) + seen_chains.add(chain_id) + chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id + + # Check that we actually have a chain ID for all the events. + events_missing_chain_info = initial_events.difference(chain_info) + if events_missing_chain_info: + # This can happen due to e.g. downgrade/upgrade of the server. 
We + # raise an exception and fall back to the previous algorithm. + logger.info( + "Unexpectedly found that events don't have chain IDs in room %s: %s", + room_id, + events_missing_chain_info, + ) + raise _NoChainCoverIndex(room_id) + + # Corresponds to `state_sets`, except as a map from chain ID to max + # sequence number reachable from the state set. + set_to_chain = [] # type: List[Dict[int, int]] + for state_set in state_sets: + chains = {} # type: Dict[int, int] + set_to_chain.append(chains) + + for event_id in state_set: + chain_id, seq_no = chain_info[event_id] + + chains[chain_id] = max(seq_no, chains.get(chain_id, 0)) + + # Now we look up all links for the chains we have, adding chains to + # set_to_chain that are reachable from each set. + sql = """ + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM event_auth_chain_links + WHERE %s + """ + + # (We need to take a copy of `seen_chains` as we want to mutate it in + # the loop) + for batch in batch_iter(set(seen_chains), 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "origin_chain_id", batch + ) + txn.execute(sql % (clause,), args) + + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in txn: + for chains in set_to_chain: + # chains are only reachable if the origin sequence number of + # the link is less than the max sequence number in the + # origin chain. + if origin_sequence_number <= chains.get(origin_chain_id, 0): + chains[target_chain_id] = max( + target_sequence_number, chains.get(target_chain_id, 0), + ) + + seen_chains.add(target_chain_id) + + # Now for each chain we figure out the maximum sequence number reachable + # from *any* state set and the minimum sequence number reachable from + # *all* state sets. Events in that range are in the auth chain + # difference. + result = set() + + # Mapping from chain ID to the range of sequence numbers that should be + # pulled from the database. + chain_to_gap = {} # type: Dict[int, Tuple[int, int]] + + for chain_id in seen_chains: + min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) + max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain) + + if min_seq_no < max_seq_no: + # We have a non empty gap, try and fill it from the events that + # we have, otherwise add them to the list of gaps to pull out + # from the DB. + for seq_no in range(min_seq_no + 1, max_seq_no + 1): + event_id = chain_to_event.get(chain_id, {}).get(seq_no) + if event_id: + result.add(event_id) + else: + chain_to_gap[chain_id] = (min_seq_no, max_seq_no) + break + + if not chain_to_gap: + # If there are no gaps to fetch, we're done! + return result + + if isinstance(self.database_engine, PostgresEngine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) + WHERE + c.chain_id = l.chain_id + AND min_seq < sequence_number AND sequence_number <= max_seq + """ + + args = [ + (chain_id, min_no, max_no) + for chain_id, (min_no, max_no) in chain_to_gap.items() + ] + + rows = txn.execute_values(sql, args) + result.update(r for r, in rows) + else: + # For SQLite we just fall back to doing a noddy for loop. + sql = """ + SELECT event_id FROM event_auth_chains + WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ? 
+ """ + for chain_id, (min_no, max_no) in chain_to_gap.items(): + txn.execute(sql, (chain_id, min_no, max_no)) + result.update(r for r, in txn) + + return result + def _get_auth_chain_difference_txn( self, txn, state_sets: List[Set[str]] ) -> Set[str]: + """Calculates the auth chain difference using a breadth first search. + + This is used when we don't have a cover index for the room. + """ # Algorithm Description # ~~~~~~~~~~~~~~~~~~~~~ diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5e7753e09b..186f064036 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,17 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Tuple, +) import attr from prometheus_client import Counter @@ -33,9 +43,10 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder -from synapse.util.iterutils import batch_iter +from synapse.util.iterutils import batch_iter, sorted_topologically if TYPE_CHECKING: from synapse.server import HomeServer @@ -89,6 +100,14 @@ class PersistEventsStore: self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() + def get_chain_id_txn(txn): + txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") + return txn.fetchone()[0] + + self._event_chain_id_gen = build_sequence_generator( + db.engine, get_chain_id_txn, "event_auth_chain_id" + ) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -366,6 +385,36 @@ class PersistEventsStore: # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) + self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts]) + + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only ones that weren't + # rejected. + + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + backfilled=backfilled, + ) + + # We call this last as it assumes we've inserted the events into + # room_memberships, where applicable. + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + + def _persist_event_auth_chain_txn( + self, txn: LoggingTransaction, events: List[EventBase], + ) -> None: + + # We only care about state events, so this if there are no state events. + if not any(e.is_state() for e in events): + return + # We want to store event_auth mappings for rejected events, as they're # used in state res v2. 
# This is only necessary if the rejected event appears in an accepted @@ -381,31 +430,357 @@ class PersistEventsStore: "room_id": event.room_id, "auth_id": auth_id, } - for event, _ in events_and_contexts + for event in events for auth_id in event.auth_event_ids() if event.is_state() ], ) - # _store_rejected_events_txn filters out any events which were - # rejected, and returns the filtered list. - events_and_contexts = self._store_rejected_events_txn( - txn, events_and_contexts=events_and_contexts + # We now calculate chain ID/sequence numbers for any state events we're + # persisting. We ignore out of band memberships as we're not in the room + # and won't have their auth chain (we'll fix it up later if we join the + # room). + # + # See: docs/auth_chain_difference_algorithm.md + + # We ignore legacy rooms that we aren't filling the chain cover index + # for. + rows = self.db_pool.simple_select_many_txn( + txn, + table="rooms", + column="room_id", + iterable={event.room_id for event in events if event.is_state()}, + keyvalues={}, + retcols=("room_id", "has_auth_chain_index"), ) + rooms_using_chain_index = { + row["room_id"] for row in rows if row["has_auth_chain_index"] + } - # From this point onwards the events are only ones that weren't - # rejected. + state_events = { + event.event_id: event + for event in events + if event.is_state() and event.room_id in rooms_using_chain_index + } - self._update_metadata_tables_txn( + if not state_events: + return + + # Map from event ID to chain ID/sequence number. + chain_map = {} # type: Dict[str, Tuple[int, int]] + + # We need to know the type/state_key and auth events of the events we're + # calculating chain IDs for. We don't rely on having the full Event + # instances as we'll potentially be pulling more events from the DB and + # we don't need the overhead of fetching/parsing the full event JSON. + event_to_types = { + e.event_id: (e.type, e.state_key) for e in state_events.values() + } + event_to_auth_chain = { + e.event_id: e.auth_event_ids() for e in state_events.values() + } + + # Set of event IDs to calculate chain ID/seq numbers for. + events_to_calc_chain_id_for = set(state_events) + + # We check if there are any events that need to be handled in the rooms + # we're looking at. These should just be out of band memberships, where + # we didn't have the auth chain when we first persisted. + rows = self.db_pool.simple_select_many_txn( txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - backfilled=backfilled, + table="event_auth_chain_to_calculate", + keyvalues={}, + column="room_id", + iterable={e.room_id for e in state_events.values()}, + retcols=("event_id", "type", "state_key"), ) + for row in rows: + event_id = row["event_id"] + event_type = row["type"] + state_key = row["state_key"] + + # (We could pull out the auth events for all rows at once using + # simple_select_many, but this case happens rarely and almost always + # with a single row.) + auth_events = self.db_pool.simple_select_onecol_txn( + txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id", + ) - # We call this last as it assumes we've inserted the events into - # room_memberships, where applicable. 
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+            events_to_calc_chain_id_for.add(event_id)
+            event_to_types[event_id] = (event_type, state_key)
+            event_to_auth_chain[event_id] = auth_events
+
+        # First we get the chain ID and sequence numbers for the events'
+        # auth events (that aren't also currently being persisted).
+        #
+        # Note that there is an edge case here where we might not have
+        # calculated chains and sequence numbers for events that were "out
+        # of band". We handle this case by fetching the necessary info and
+        # adding it to the set of events to calculate chain IDs for.
+
+        missing_auth_chains = {
+            a_id
+            for auth_events in event_to_auth_chain.values()
+            for a_id in auth_events
+            if a_id not in events_to_calc_chain_id_for
+        }
+
+        # We loop here in case we find an out of band membership and need to
+        # fetch their auth event info.
+        while missing_auth_chains:
+            sql = """
+                SELECT event_id, events.type, state_key, chain_id, sequence_number
+                FROM events
+                INNER JOIN state_events USING (event_id)
+                LEFT JOIN event_auth_chains USING (event_id)
+                WHERE
+            """
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", missing_auth_chains,
+            )
+            txn.execute(sql + clause, args)
+
+            missing_auth_chains.clear()
+
+            for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+                event_to_types[auth_id] = (event_type, state_key)
+
+                if chain_id is None:
+                    # No chain ID, so the event was persisted out of band.
+                    # We add it to the list of events to calculate auth chains
+                    # for.
+
+                    events_to_calc_chain_id_for.add(auth_id)
+
+                    event_to_auth_chain[
+                        auth_id
+                    ] = self.db_pool.simple_select_onecol_txn(
+                        txn,
+                        "event_auth",
+                        keyvalues={"event_id": auth_id},
+                        retcol="auth_id",
+                    )
+
+                    missing_auth_chains.update(
+                        e
+                        for e in event_to_auth_chain[auth_id]
+                        if e not in event_to_types
+                    )
+                else:
+                    chain_map[auth_id] = (chain_id, sequence_number)
+
+        # Now we check if we have any events where we don't have the auth
+        # chain; this should only be out of band memberships.
+        for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
+            for auth_id in event_to_auth_chain[event_id]:
+                if (
+                    auth_id not in chain_map
+                    and auth_id not in events_to_calc_chain_id_for
+                ):
+                    events_to_calc_chain_id_for.discard(event_id)
+
+                    # If this is an event we're trying to persist we add it to
+                    # the list of events to calculate chain IDs for next time
+                    # around. (Otherwise we will have already added it to the
+                    # table).
+                    event = state_events.get(event_id)
+                    if event:
+                        self.db_pool.simple_insert_txn(
+                            txn,
+                            table="event_auth_chain_to_calculate",
+                            values={
+                                "event_id": event.event_id,
+                                "room_id": event.room_id,
+                                "type": event.type,
+                                "state_key": event.state_key,
+                            },
+                        )
+
+                    # We stop checking the event's auth events since we've
+                    # discarded it.
+                    break
+
+        if not events_to_calc_chain_id_for:
+            return
+
+        # We now calculate the chain IDs/sequence numbers for the events. We
+        # do this by looking at the chain ID and sequence number of any auth
+        # event with the same type/state_key and incrementing the sequence
+        # number by one. If there was no match or the chain ID/sequence
+        # number is already taken we generate a new chain.
+        #
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events
+        # before the event itself.
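+        # For example, using the chains from
+        # docs/auth_chain_difference_algorithm.md: a new join for Alice whose
+        # auth events include her existing join at (4, 3) would be allocated
+        # (4, 4), as the two share a (type, state_key); if (4, 4) were already
+        # taken, or no auth event matched, a brand new chain (N, 1) would be
+        # allocated instead.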
+ chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
+ new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ existing_chain_id = None
+ for auth_id in event_to_auth_chain[event_id]:
+ if event_to_types.get(event_id) == event_to_types.get(auth_id):
+ existing_chain_id = chain_map[auth_id]
+ break
+
+ new_chain_tuple = None
+ if existing_chain_id:
+ # We found a chain ID/sequence number candidate, check it's
+ # not already taken.
+ proposed_new_id = existing_chain_id[0]
+ proposed_new_seq = existing_chain_id[1] + 1
+ if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
+ already_allocated = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_auth_chains",
+ keyvalues={
+ "chain_id": proposed_new_id,
+ "sequence_number": proposed_new_seq,
+ },
+ retcol="event_id",
+ allow_none=True,
+ )
+ if already_allocated:
+ # Mark it as already allocated so we don't need to hit
+ # the DB again.
+ chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
+ else:
+ new_chain_tuple = (
+ proposed_new_id,
+ proposed_new_seq,
+ )
+
+ if not new_chain_tuple:
+ new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
+
+ chains_tuples_allocated.add(new_chain_tuple)
+
+ chain_map[event_id] = new_chain_tuple
+ new_chain_tuples[event_id] = new_chain_tuple
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chains",
+ values=[
+ {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+ for event_id, (c_id, seq) in new_chain_tuples.items()
+ ],
+ )
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="event_id",
+ iterable=new_chain_tuples,
+ )
+
+ # Now we need to calculate any new links between chains caused by
+ # the new events.
+ #
+ # Links are pairs of chain ID/sequence numbers such that for any
+ # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
+ # if and only if there is at least one link (CA, S1) -> (CB, S2)
+ # where SA >= S1 and S2 >= SB.
+ #
+ # We try and avoid adding redundant links to the table, e.g. if we
+ # have two links between the same two chains where one link "crosses"
+ # the other then one of them can be safely dropped.
+ #
+ # To calculate new links we look at every new event and:
+ # 1. Fetch the chain ID/sequence numbers of its auth events,
+ # discarding any that are reachable by other auth events, or
+ # that have the same chain ID as the event.
+ # 2. For each retained auth event we:
+ # a. Add a link from the event's chain ID/sequence number to
+ # the auth event's; and
+ # b. Add a link from the event to every chain reachable by the
+ # auth event.
+
+ # Step 1, fetch all existing links from all the chains we've seen
+ # referenced.
+ chain_links = _LinkMap()
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable={chain_id for chain_id, _ in chain_map.values()},
+ keyvalues={},
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ )
+ for row in rows:
+ chain_links.add_link(
+ (row["origin_chain_id"], row["origin_sequence_number"]),
+ (row["target_chain_id"], row["target_sequence_number"]),
+ new=False,
+ )
+
+ # We do this in topological order to avoid adding redundant links.
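+ #
+ # For example, given an existing link (CA, 1) -> (CB, 2), a later
+ # link (CA, 3) -> (CB, 1) would be redundant: any event at (CA, n)
+ # with n >= 3 already reaches every event at (CB, m) with m <= 1
+ # via the existing link.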
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ chain_id, sequence_number = chain_map[event_id]
+
+ # Filter out auth events that are reachable by other auth
+ # events. We do this by looking at every permutation of pairs of
+ # auth events (A, B) to check if B is reachable from A.
+ reduction = {
+ a_id
+ for a_id in event_to_auth_chain[event_id]
+ if chain_map[a_id][0] != chain_id
+ }
+ for start_auth_id, end_auth_id in itertools.permutations(
+ event_to_auth_chain[event_id], r=2,
+ ):
+ if chain_links.exists_path_from(
+ chain_map[start_auth_id], chain_map[end_auth_id]
+ ):
+ reduction.discard(end_auth_id)
+
+ # Step 2, figure out what the new links are from the reduced
+ # list of auth events.
+ for auth_id in reduction:
+ auth_chain_id, auth_sequence_number = chain_map[auth_id]
+
+ # Step 2a, add link between the event and auth event
+ chain_links.add_link(
+ (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
+ )
+
+ # Step 2b, add a link to chains reachable from the auth
+ # event.
+ for target_id, target_seq in chain_links.get_links_from(
+ (auth_chain_id, auth_sequence_number)
+ ):
+ if target_id == chain_id:
+ continue
+
+ chain_links.add_link(
+ (chain_id, sequence_number), (target_id, target_seq)
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ values=[
+ {
+ "origin_chain_id": source_id,
+ "origin_sequence_number": source_seq,
+ "target_chain_id": target_id,
+ "target_sequence_number": target_seq,
+ }
+ for (
+ source_id,
+ source_seq,
+ target_id,
+ target_seq,
+ ) in chain_links.get_additions()
+ ],
+ )

 def _persist_transaction_ids_txn(
 self,
@@ -1521,3 +1896,131 @@ class PersistEventsStore:
 if not ev.internal_metadata.is_outlier()
 ],
 )
+
+
+@attr.s(slots=True)
+class _LinkMap:
+ """A helper type for tracking links between chains.
+ """
+
+ # Stores the set of links as nested maps: source chain ID -> target chain ID
+ # -> source sequence number -> target sequence number.
+ maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+
+ # Stores the links that have been added (with new set to true), as tuples of
+ # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
+ additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+
+ def add_link(
+ self,
+ src_tuple: Tuple[int, int],
+ target_tuple: Tuple[int, int],
+ new: bool = True,
+ ) -> bool:
+ """Add a new link between two chains, ensuring no redundant links are added.
+
+ New links should be added in topological order.
+
+ Args:
+ src_tuple: The chain ID/sequence number of the source of the link.
+ target_tuple: The chain ID/sequence number of the target of the link.
+ new: Whether this is a "new" link, i.e. should it be returned
+ by `get_additions`.
+
+ Returns:
+ True if a link was added, False if the given link was dropped as redundant
+ """
+ src_chain, src_seq = src_tuple
+ target_chain, target_seq = target_tuple
+
+ current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
+
+ assert src_chain != target_chain
+
+ if new:
+ # Check if the new link is redundant
+ for current_seq_src, current_seq_target in current_links.items():
+ # If a link "crosses" another link then it's redundant. For example,
+ # in the following link 1 (L1) is redundant, as any event reachable
+ # via L1 is *also* reachable via L2.
+ # + # Chain A Chain B + # | | + # L1 |------ | + # | | | + # L2 |---- | -->| + # | | | + # | |--->| + # | | + # | | + # + # So we only need to keep links which *do not* cross, i.e. links + # that both start and end above or below an existing link. + # + # Note, since we add links in topological ordering we should never + # see `src_seq` less than `current_seq_src`. + + if current_seq_src <= src_seq and target_seq <= current_seq_target: + # This new link is redundant, nothing to do. + return False + + self.additions.add((src_chain, src_seq, target_chain, target_seq)) + + current_links[src_seq] = target_seq + return True + + def get_links_from( + self, src_tuple: Tuple[int, int] + ) -> Generator[Tuple[int, int], None, None]: + """Gets the chains reachable from the given chain/sequence number. + + Yields: + The chain ID and sequence number the link points to. + """ + src_chain, src_seq = src_tuple + for target_id, sequence_numbers in self.maps.get(src_chain, {}).items(): + for link_src_seq, target_seq in sequence_numbers.items(): + if link_src_seq <= src_seq: + yield target_id, target_seq + + def get_links_between( + self, source_chain: int, target_chain: int + ) -> Generator[Tuple[int, int], None, None]: + """Gets the links between two chains. + + Yields: + The source and target sequence numbers. + """ + + yield from self.maps.get(source_chain, {}).get(target_chain, {}).items() + + def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]: + """Gets any newly added links. + + Yields: + The source chain ID/sequence number and target chain ID/sequence number + """ + + for src_chain, src_seq, target_chain, _ in self.additions: + target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq) + if target_seq is not None: + yield (src_chain, src_seq, target_chain, target_seq) + + def exists_path_from( + self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int], + ) -> bool: + """Checks if there is a path between the source chain ID/sequence and + target chain ID/sequence. + """ + src_chain, src_seq = src_tuple + target_chain, target_seq = target_tuple + + if src_chain == target_chain: + return target_seq <= src_seq + + links = self.get_links_between(src_chain, target_chain) + for link_start_seq, link_end_seq in links: + if link_start_seq <= src_seq and target_seq <= link_end_seq: + return True + + return False diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4650d0689b..284f2ce77c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore): return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), + retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), desc="get_room", allow_none=True, ) @@ -1166,6 +1166,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore): # It's overridden by RoomStore for the synapse master. raise NotImplementedError() + async def has_auth_chain_index(self, room_id: str) -> bool: + """Check if the room has (or can have) a chain cover index. + + Defaults to True if we don't have an entry in `rooms` table nor any + events for the room. 
+ """ + + has_auth_chain_index = await self.db_pool.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="has_auth_chain_index", + desc="has_auth_chain_index", + allow_none=True, + ) + + if has_auth_chain_index: + return True + + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + max_ordering = await self.db_pool.simple_select_one_onecol( + table="events", + keyvalues={"room_id": room_id}, + retcol="MAX(stream_ordering)", + allow_none=True, + desc="upsert_room_on_join", + ) + + return max_ordering is None + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -1179,12 +1210,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): Called when we join a room over federation, and overwrites any room version currently in the table. """ + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + has_auth_chain_index = await self.has_auth_chain_index(room_id) + await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", keyvalues={"room_id": room_id}, values={"room_version": room_version.identifier}, - insertion_values={"is_public": False, "creator": ""}, + insertion_values={ + "is_public": False, + "creator": "", + "has_auth_chain_index": has_auth_chain_index, + }, # rooms has a unique constraint on room_id, so no need to lock when doing an # emulated upsert. lock=False, @@ -1219,6 +1259,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "creator": room_creator_user_id, "is_public": is_public, "room_version": room_version.identifier, + "has_auth_chain_index": True, }, ) if is_public: @@ -1247,6 +1288,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): When we receive an invite or any other event over federation that may relate to a room we are not in, store the version of the room if we don't already know the room version. """ + # It's possible that we already have events for the room in our DB + # without a corresponding room entry. If we do then we don't want to + # mark the room as having an auth chain cover index. + has_auth_chain_index = await self.has_auth_chain_index(room_id) + await self.db_pool.simple_upsert( desc="maybe_store_room_on_outlier_membership", table="rooms", @@ -1256,6 +1302,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "room_version": room_version.identifier, "is_public": False, "creator": "", + "has_auth_chain_index": has_auth_chain_index, }, # rooms has a unique constraint on room_id, so no need to lock when doing an # emulated upsert. diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql new file mode 100644 index 0000000000..729196cfd5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql @@ -0,0 +1,52 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- See docs/auth_chain_difference_algorithm.md + +CREATE TABLE event_auth_chains ( + event_id TEXT PRIMARY KEY, + chain_id BIGINT NOT NULL, + sequence_number BIGINT NOT NULL +); + +CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number); + + +CREATE TABLE event_auth_chain_links ( + origin_chain_id BIGINT NOT NULL, + origin_sequence_number BIGINT NOT NULL, + + target_chain_id BIGINT NOT NULL, + target_sequence_number BIGINT NOT NULL +); + + +CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id); + + +-- Events that we have persisted but not calculated auth chains for, +-- e.g. out of band memberships (where we don't have the auth chain) +CREATE TABLE event_auth_chain_to_calculate ( + event_id TEXT PRIMARY KEY, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL +); + +CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id); + + +-- Whether we've calculated the above index for a room. +ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN; diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres new file mode 100644 index 0000000000..e8a035bbeb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id; diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 06faeebe7f..f7b4857a84 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -13,8 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import heapq from itertools import islice -from typing import Iterable, Iterator, Sequence, Tuple, TypeVar +from typing import ( + Dict, + Generator, + Iterable, + Iterator, + Mapping, + Sequence, + Set, + Tuple, + TypeVar, +) + +from synapse.types import Collection T = TypeVar("T") @@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: If the input is empty, no chunks are returned. 
""" return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) + + +def sorted_topologically( + nodes: Iterable[T], graph: Mapping[T, Collection[T]], +) -> Generator[T, None, None]: + """Given a set of nodes and a graph, yield the nodes in toplogical order. + + For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`. + """ + + # This is implemented by Kahn's algorithm. + + degree_map = {node: 0 for node in nodes} + reverse_graph = {} # type: Dict[T, Set[T]] + + for node, edges in graph.items(): + if node not in degree_map: + continue + + for edge in edges: + if edge in degree_map: + degree_map[node] += 1 + + reverse_graph.setdefault(edge, set()).add(node) + reverse_graph.setdefault(node, set()) + + zero_degree = [node for node, degree in degree_map.items() if degree == 0] + heapq.heapify(zero_degree) + + while zero_degree: + node = heapq.heappop(zero_degree) + yield node + + for edge in reverse_graph[node]: + if edge in degree_map: + degree_map[edge] -= 1 + if degree_map[edge] == 0: + heapq.heappush(zero_degree, edge) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py new file mode 100644 index 0000000000..83c377824b --- /dev/null +++ b/tests/storage/test_event_chain.py @@ -0,0 +1,472 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + +from twisted.trial import unittest + +from synapse.api.constants import EventTypes +from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.storage.databases.main.events import _LinkMap + +from tests.unittest import HomeserverTestCase + + +class EventChainStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self._next_stream_ordering = 1 + + def test_simple(self): + """Test that the example in `docs/auth_chain_difference_algorithm.md` + works. + """ + + event_factory = self.hs.get_event_builder_factory() + bob = "@creator:test" + alice = "@alice:test" + room_id = "!room:test" + + # Ensure that we have a rooms entry so that we generate the chain index. 
+ self.get_success( + self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=True, + room_version=RoomVersions.V6, + ) + ) + + create = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Create, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "create"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[]) + ) + + bob_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) + ) + + power = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power"}, + }, + ).build( + prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + ) + ) + + alice_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": bob, + "room_id": room_id, + "content": {"tag": "alice_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], + ) + ) + + power_2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power_2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + bob_join_2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join_2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[ + create.event_id, + alice_join.event_id, + power_2.event_id, + ], + ) + ) + + events = [ + create, + bob_join, + power, + alice_invite, + alice_join, + bob_join_2, + power_2, + alice_join2, + ] + + expected_links = [ + (bob_join, create), + (power, create), + (power, bob_join), + (alice_invite, create), + (alice_invite, power), + (alice_invite, bob_join), + (bob_join_2, power), + (alice_join2, power_2), + ] + + self.persist(events) + chain_map, link_map = self.fetch_chains(events) + + # Check that the expected links and only the expected links have been + # added. 
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) + + # Test that everything can reach the create event, but the create event + # can't reach anything. + for event in events[1:]: + self.assertTrue( + link_map.exists_path_from( + chain_map[event.event_id], chain_map[create.event_id] + ), + ) + + self.assertFalse( + link_map.exists_path_from( + chain_map[create.event_id], chain_map[event.event_id], + ), + ) + + def test_out_of_order_events(self): + """Test that we handle persisting events that we don't have the full + auth chain for yet (which should only happen for out of band memberships). + """ + event_factory = self.hs.get_event_builder_factory() + bob = "@creator:test" + alice = "@alice:test" + room_id = "!room:test" + + # Ensure that we have a rooms entry so that we generate the chain index. + self.get_success( + self.store.store_room( + room_id=room_id, + room_creator_user_id="", + is_public=True, + room_version=RoomVersions.V6, + ) + ) + + # First persist the base room. + create = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Create, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "create"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[]) + ) + + bob_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": bob, + "sender": bob, + "room_id": room_id, + "content": {"tag": "bob_join"}, + }, + ).build(prev_event_ids=[], auth_event_ids=[create.event_id]) + ) + + power = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.PowerLevels, + "state_key": "", + "sender": bob, + "room_id": room_id, + "content": {"tag": "power"}, + }, + ).build( + prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id], + ) + ) + + self.persist([create, bob_join, power]) + + # Now persist an invite and a couple of memberships out of order. + alice_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": bob, + "room_id": room_id, + "content": {"tag": "alice_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, bob_join.event_id, power.event_id], + ) + ) + + alice_join = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id], + ) + ) + + alice_join2 = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": alice, + "sender": alice, + "room_id": room_id, + "content": {"tag": "alice_join2"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[create.event_id, alice_join.event_id, power.event_id], + ) + ) + + self.persist([alice_join]) + self.persist([alice_join2]) + self.persist([alice_invite]) + + # The end result should be sane. 
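+ # (i.e. the chain cover should look the same as if the events had
+ # been persisted in the usual order).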
+ events = [create, bob_join, power, alice_invite, alice_join] + + chain_map, link_map = self.fetch_chains(events) + + expected_links = [ + (bob_join, create), + (power, create), + (power, bob_join), + (alice_invite, create), + (alice_invite, power), + (alice_invite, bob_join), + ] + + # Check that the expected links and only the expected links have been + # added. + self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) + + def persist( + self, events: List[EventBase], + ): + """Persist the given events and check that the links generated match + those given. + """ + + persist_events_store = self.hs.get_datastores().persist_events + + for e in events: + e.internal_metadata.stream_ordering = self._next_stream_ordering + self._next_stream_ordering += 1 + + def _persist(txn): + # We need to persist the events to the events and state_events + # tables. + persist_events_store._store_event_txn(txn, [(e, {}) for e in events]) + + # Actually call the function that calculates the auth chain stuff. + persist_events_store._persist_event_auth_chain_txn(txn, events) + + self.get_success( + persist_events_store.db_pool.runInteraction("_persist", _persist,) + ) + + def fetch_chains( + self, events: List[EventBase] + ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: + + # Fetch the map from event ID -> (chain ID, sequence number) + rows = self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chains", + column="event_id", + iterable=[e.event_id for e in events], + retcols=("event_id", "chain_id", "sequence_number"), + keyvalues={}, + ) + ) + + chain_map = { + row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows + } + + # Fetch all the links and pass them to the _LinkMap. + rows = self.get_success( + self.store.db_pool.simple_select_many_batch( + table="event_auth_chain_links", + column="origin_chain_id", + iterable=[chain_id for chain_id, _ in chain_map.values()], + retcols=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), + keyvalues={}, + ) + ) + + link_map = _LinkMap() + for row in rows: + added = link_map.add_link( + (row["origin_chain_id"], row["origin_sequence_number"]), + (row["target_chain_id"], row["target_sequence_number"]), + ) + + # We shouldn't have persisted any redundant links + self.assertTrue(added) + + return chain_map, link_map + + +class LinkMapTestCase(unittest.TestCase): + def test_simple(self): + """Basic tests for the LinkMap. + """ + link_map = _LinkMap() + + link_map.add_link((1, 1), (2, 1), new=False) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)]) + self.assertCountEqual(link_map.get_additions(), []) + self.assertTrue(link_map.exists_path_from((1, 5), (2, 1))) + self.assertFalse(link_map.exists_path_from((1, 5), (2, 2))) + self.assertTrue(link_map.exists_path_from((1, 5), (1, 1))) + self.assertFalse(link_map.exists_path_from((1, 1), (1, 5))) + + # Attempting to add a redundant link is ignored. 
+ self.assertFalse(link_map.add_link((1, 4), (2, 1))) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + + # Adding new non-redundant links works + self.assertTrue(link_map.add_link((1, 3), (2, 3))) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + + self.assertTrue(link_map.add_link((2, 5), (1, 3))) + self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)]) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + + self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 482506d731..9d04a066d8 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,6 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import attr +from parameterized import parameterized + +from synapse.events import _EventInternalMetadata + import tests.unittest import tests.utils @@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) - def test_auth_difference(self): + @parameterized.expand([(True,), (False,)]) + def test_auth_difference(self, use_chain_cover_index: bool): room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): "j": 1, } + # Mark the room as not having a cover index + + def store_room(txn): + self.store.db_pool.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + "has_auth_chain_index": use_chain_cover_index, + }, + ) + + self.get_success(self.store.db_pool.runInteraction("store_room", store_room)) + # We rudely fiddle with the appropriate tables directly, as that's much # easier than constructing events properly. 
- def insert_event(txn, event_id, stream_ordering): + def insert_event(txn): + stream_ordering = 0 + + for event_id in auth_graph: + stream_ordering += 1 + depth = depth_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.hs.datastores.persist_events._persist_event_auth_chain_txn( + txn, + [ + FakeEvent(event_id, room_id, auth_graph[event_id]) + for event_id in auth_graph + ], + ) + + self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}]) + ) + self.assertSetEqual(difference, set()) + + def test_auth_difference_partial_cover(self): + """Test that we correctly handle rooms where not all events have a chain + cover calculated. This can happen in some obscure edge cases, including + during the background update that calculates the chain cover for old + rooms. + """ + + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ´ | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } - depth = depth_map[event_id] + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + def insert_event(txn): + # First insert the room and mark it as having a chain cover. 
self.store.db_pool.simple_insert_txn( txn, - table="events", - values={ - "event_id": event_id, + "rooms", + { "room_id": room_id, - "depth": depth, - "topological_ordering": depth, - "type": "m.test", - "processed": True, - "outlier": False, - "stream_ordering": stream_ordering, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + "has_auth_chain_index": True, }, ) - self.store.db_pool.simple_insert_many_txn( + stream_ordering = 0 + + for event_id in auth_graph: + stream_ordering += 1 + depth = depth_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + # Insert all events apart from 'B' + self.hs.datastores.persist_events._persist_event_auth_chain_txn( txn, - table="event_auth", - values=[ - {"event_id": event_id, "room_id": room_id, "auth_id": a} - for a in auth_graph[event_id] + [ + FakeEvent(event_id, room_id, auth_graph[event_id]) + for event_id in auth_graph + if event_id != "b" ], ) - next_stream_ordering = 0 - for event_id in auth_graph: - next_stream_ordering += 1 - self.get_success( - self.store.db_pool.runInteraction( - "insert", insert_event, event_id, next_stream_ordering - ) + # Now we insert the event 'B' without a chain cover, by temporarily + # pretending the room doesn't have a chain cover. + + self.store.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": False}, + ) + + self.hs.datastores.persist_events._persist_event_auth_chain_txn( + txn, [FakeEvent("b", room_id, auth_graph["b"])], + ) + + self.store.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": True}, ) + self.get_success(self.store.db_pool.runInteraction("insert", insert_event,)) + # Now actually test that various combinations give the right result: difference = self.get_success( @@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_auth_chain_difference(room_id, [{"a"}]) ) self.assertSetEqual(difference, set()) + + +@attr.s +class FakeEvent: + event_id = attr.ib() + room_id = attr.ib() + auth_events = attr.ib() + + type = "foo" + state_key = "foo" + + internal_metadata = _EventInternalMetadata({}) + + def auth_event_ids(self): + return self.auth_events + + def is_state(self): + return True diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 0ab0a91483..1184cea5a3 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically

 from tests.unittest import TestCase

@@ -45,3 +47,40 @@ class ChunkSeqTests(TestCase):
 self.assertEqual(
 list(parts), [],
 )
+
+
+class SortTopologically(TestCase):
+ def test_empty(self):
+ "Test that an empty graph works correctly"
+
+ graph = {} # type: Dict[int, List[int]]
+ self.assertEqual(list(sorted_topologically([], graph)), [])
+
+ def test_disconnected(self):
+ "Test that a graph with no edges works"
+
+ graph = {1: [], 2: []} # type: Dict[int, List[int]]
+
+ # For disconnected nodes the output is simply sorted.
+ self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+ def test_linear(self):
+ "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+ def test_subset(self):
+ "Test that only sorting a subset of the graph works"
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+ def test_fork(self):
+ "Test that a forked graph works"
+ graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]]
+
+ # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+ # always get the same one.
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
--
cgit 1.5.1


From 42d3a28d8bb8c08e9e0d00a2e247cbbddb1a155c Mon Sep 17 00:00:00 2001
From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com>
Date: Mon, 11 Jan 2021 17:15:54 +0100
Subject: Removes unnecessary declarations in the tests for the admin API.
 (#9063)

---
 changelog.d/9063.misc | 1 +
 tests/rest/admin/test_admin.py | 3 ---
 tests/rest/admin/test_event_reports.py | 4 ----
 tests/rest/admin/test_media.py | 2 --
 tests/rest/admin/test_room.py | 2 --
 tests/rest/admin/test_statistics.py | 1 -
 tests/rest/admin/test_user.py | 5 -----
 7 files changed, 1 insertion(+), 17 deletions(-)
 create mode 100644 changelog.d/9063.misc

(limited to 'tests')

diff --git a/changelog.d/9063.misc b/changelog.d/9063.misc
new file mode 100644
index 0000000000..22eed43147
--- /dev/null
+++ b/changelog.d/9063.misc
@@ -0,0 +1 @@
+Removes unnecessary declarations in the tests for the admin API.
\ No newline at end of file diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 0504cd187e..586b877bda 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -58,8 +58,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -156,7 +154,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.hs = hs # Allow for uploading and downloading to/from the media repo self.media_repo = hs.get_media_repository_resource() diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index aa389df12f..d0090faa4f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -32,8 +32,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -371,8 +369,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index c2b998cdae..51a7731693 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -35,7 +35,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.handler = hs.get_device_handler() self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -181,7 +180,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.handler = hs.get_device_handler() self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index fa620f97f3..a0f32c5512 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -605,8 +605,6 @@ class RoomTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - # Create user self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 73f8a8ec99..f48be3d65a 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -31,7 +31,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9b2e4765f6..877fd2587b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1204,8 +1204,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.admin_user = self.register_user("admin", "pass", 
admin=True)
 self.admin_user_tok = self.login("admin", "pass")

@@ -1401,7 +1399,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 ]

 def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
 self.media_repo = hs.get_media_repository_resource()

 self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -1868,8 +1865,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
 ]

 def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
 self.admin_user = self.register_user("admin", "pass", admin=True)
 self.admin_user_tok = self.login("admin", "pass")
--
cgit 1.5.1


From b161528fccac4bf17f7afa9438a75d796433194e Mon Sep 17 00:00:00 2001
From: David Teller
Date: Mon, 11 Jan 2021 20:32:17 +0100
Subject: Also support remote users on the joined_rooms admin API. (#8948)

For remote users, only the rooms which the server knows about are
returned. Local users have all of their joined rooms returned.
---
 changelog.d/8948.feature | 1 +
 docs/admin_api/user_admin_api.rst | 4 +++
 synapse/rest/admin/users.py | 7 -----
 tests/rest/admin/test_user.py | 58 +++++++++++++++++++++++++++++++++++----
 4 files changed, 57 insertions(+), 13 deletions(-)
 create mode 100644 changelog.d/8948.feature

(limited to 'tests')

diff --git a/changelog.d/8948.feature b/changelog.d/8948.feature
new file mode 100644
index 0000000000..3b06cbfa22
--- /dev/null
+++ b/changelog.d/8948.feature
@@ -0,0 +1 @@
+Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users.
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index e4d6f8203b..3115951e1f 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -337,6 +337,10 @@ A response body like the following is returned:
 "total": 2
 }

+The server returns the list of rooms of which the user and the server
+are members. If the user is local, all of the rooms of which the user
+is a member are returned.
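+
+For example, requesting the joined rooms of a remote user (the user ID
+here is illustrative) only returns the rooms this server knows about::
+
+    GET /_synapse/admin/v1/users/@remote_user:other_server/joined_rooms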
+ **Parameters** The following parameters should be set in the URL: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 6658c2da56..f8a73e7d9d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -714,13 +714,6 @@ class UserMembershipRestServlet(RestServlet): async def on_GET(self, request, user_id): await assert_requester_is_admin(self.auth, request) - if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") - - user = await self.store.get_user_by_id(user_id) - if user is None: - raise NotFoundError("Unknown user") - room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} return 200, ret diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 877fd2587b..ad4588c1da 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -25,6 +25,7 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError +from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync @@ -1234,24 +1235,26 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" channel = self.make_request("GET", url, access_token=self.admin_user_tok,) - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" channel = self.make_request("GET", url, access_token=self.admin_user_tok,) - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) def test_no_memberships(self): """ @@ -1282,6 +1285,49 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) + def test_get_rooms_with_nonlocal_user(self): + """ + Tests that a normal lookup for rooms is successful with a non-local user + """ + + other_user_tok = self.login("user", "pass") + event_builder_factory = self.hs.get_event_builder_factory() + event_creation_handler = self.hs.get_event_creation_handler() + storage = self.hs.get_storage() + + # Create two rooms, one with a local user only and one with both a local + # and remote user. 
+ self.helper.create_room_as(self.other_user, tok=other_user_tok) + local_and_remote_room_id = self.helper.create_room_as( + self.other_user, tok=other_user_tok + ) + + # Add a remote user to the room. + builder = event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": "m.room.member", + "sender": "@joiner:remote_hs", + "state_key": "@joiner:remote_hs", + "room_id": local_and_remote_room_id, + "content": {"membership": "join"}, + }, + ) + + event, context = self.get_success( + event_creation_handler.create_new_client_event(builder) + ) + + self.get_success(storage.persistence.persist_event(event, context)) + + # Now get rooms + url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) + class PushersRestTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 2ec8ca5e6046b207eabe008101631b978758ac6d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 12 Jan 2021 12:34:16 +0000 Subject: Remove SynapseRequest.get_user_agent (#9069) SynapseRequest is in danger of becoming a bit of a dumping-ground for "useful stuff relating to Requests", which isn't really its intention (its purpose is to override render, finished and connectionLost to set up the LoggingContext and write the right entries to the request log). Putting utility functions inside SynapseRequest means that lots of our code ends up requiring a SynapseRequest when there is nothing synapse-specific about the Request at all, and any old twisted.web.iweb.IRequest will do. This increases code coupling and makes testing more difficult. In short: move get_user_agent out to a utility function. --- changelog.d/9069.misc | 1 + synapse/api/auth.py | 3 ++- synapse/handlers/auth.py | 6 +++--- synapse/handlers/sso.py | 5 +++-- synapse/http/__init__.py | 15 +++++++++++++++ synapse/http/site.py | 18 ++---------------- tests/handlers/test_cas.py | 2 +- tests/handlers/test_oidc.py | 3 +-- tests/handlers/test_saml.py | 2 +- 9 files changed, 29 insertions(+), 26 deletions(-) create mode 100644 changelog.d/9069.misc (limited to 'tests') diff --git a/changelog.d/9069.misc b/changelog.d/9069.misc new file mode 100644 index 0000000000..5e9e62d252 --- /dev/null +++ b/changelog.d/9069.misc @@ -0,0 +1 @@ +Remove `SynapseRequest.get_user_agent`. 
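Because the helper introduced below only depends on twisted's `IRequest`, tests
can stub a request with a plain `Mock` exposing `getHeader`, as the updated
tests do. A minimal sketch (the header value is illustrative):

    from mock import Mock

    from synapse.http import get_request_user_agent

    request = Mock(spec=["getHeader"])
    request.getHeader.return_value = b"tests/1.0"
    assert get_request_user_agent(request) == "tests/1.0"

    # With no User-Agent header, the supplied default is returned.
    request.getHeader.return_value = None
    assert get_request_user_agent(request, "-") == "-"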
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 48c4d7b0be..6d6703250b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.appservice import ApplicationService from synapse.events import EventBase +from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging import opentracing as opentracing from synapse.storage.databases.main.registration import TokenLookupResult @@ -187,7 +188,7 @@ class Auth: """ try: ip_addr = self.hs.get_ip_from_request(request) - user_agent = request.get_user_agent("") + user_agent = get_request_user_agent(request) access_token = self.get_access_token_from_request(request) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f4434673dc..5b86ee85c7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -49,8 +49,10 @@ from synapse.api.errors import ( UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.handlers._base import BaseHandler from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.http import get_request_user_agent from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread @@ -62,8 +64,6 @@ from synapse.util.async_helpers import maybe_awaitable from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.app.homeserver import HomeServer @@ -539,7 +539,7 @@ class AuthHandler(BaseHandler): # authentication flow. 
await self.store.set_ui_auth_clientdict(sid, clientdict) - user_agent = request.get_user_agent("") + user_agent = get_request_user_agent(request) await self.store.add_user_agent_ip_to_ui_auth_session( session.session_id, user_agent, clientip diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 2da1ea2223..740df7e4a0 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -23,6 +23,7 @@ from typing_extensions import NoReturn, Protocol from twisted.web.http import Request from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters @@ -362,7 +363,7 @@ class SsoHandler: attributes, auth_provider_id, remote_user_id, - request.get_user_agent(""), + get_request_user_agent(request), request.getClientIP(), ) @@ -628,7 +629,7 @@ class SsoHandler: attributes, session.auth_provider_id, session.remote_user_id, - request.get_user_agent(""), + get_request_user_agent(request), request.getClientIP(), ) diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 59b01b812c..4bc3cb53f0 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -17,6 +17,7 @@ import re from twisted.internet import task from twisted.web.client import FileBodyProducer +from twisted.web.iweb import IRequest from synapse.api.errors import SynapseError @@ -50,3 +51,17 @@ class QuieterFileBodyProducer(FileBodyProducer): FileBodyProducer.stopProducing(self) except task.TaskStopped: pass + + +def get_request_user_agent(request: IRequest, default: str = "") -> str: + """Return the last User-Agent header, or the given default. + """ + # There could be raw utf-8 bytes in the User-Agent header. + + # N.B. if you don't do this, the logger explodes cryptically + # with maximum recursion trying to log errors about + # the charset problem. + # c.f. https://github.com/matrix-org/synapse/issues/3471 + + h = request.getHeader(b"User-Agent") + return h.decode("ascii", "replace") if h else default diff --git a/synapse/http/site.py b/synapse/http/site.py index 5a5790831b..12ec3f851f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -20,7 +20,7 @@ from twisted.python.failure import Failure from twisted.web.server import Request, Site from synapse.config.server import ListenerConfig -from synapse.http import redact_uri +from synapse.http import get_request_user_agent, redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.types import Requester @@ -113,15 +113,6 @@ class SynapseRequest(Request): method = self.method.decode("ascii") return method - def get_user_agent(self, default: str) -> str: - """Return the last User-Agent header, or the given default. - """ - user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] - if user_agent is None: - return default - - return user_agent.decode("ascii", "replace") - def render(self, resrc): # this is called once a Resource has been found to serve the request; in our # case the Resource in question will normally be a JsonResource. @@ -292,12 +283,7 @@ class SynapseRequest(Request): # and can see that we're doing something wrong. authenticated_entity = repr(self.requester) # type: ignore[unreachable] - # ...or could be raw utf-8 bytes in the User-Agent header. - # N.B. 
if you don't do this, the logger explodes cryptically - # with maximum recursion trying to log errors about - # the charset problem. - # c.f. https://github.com/matrix-org/synapse/issues/3471 - user_agent = self.get_user_agent("-") + user_agent = get_request_user_agent(self, "-") code = str(self.code) if not self.finished: diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index bd7a1b6891..c37bb6440e 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -118,4 +118,4 @@ class CasHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "get_user_agent"]) + return Mock(spec=["getClientIP", "getHeader"]) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index f5df657814..4ce0f74f22 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -1011,7 +1011,7 @@ def _build_callback_request( "addCookie", "requestHeaders", "getClientIP", - "get_user_agent", + "getHeader", ] ) @@ -1020,5 +1020,4 @@ def _build_callback_request( request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] request.getClientIP.return_value = ip_address - request.get_user_agent.return_value = user_agent return request diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 548038214b..261c7083d1 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -262,4 +262,4 @@ class SamlHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "get_user_agent"]) + return Mock(spec=["getClientIP", "getHeader"]) -- cgit 1.5.1 From 723b19748aca0de4ae0eb2d6d89844fad04f29ae Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 12 Jan 2021 11:07:01 -0500 Subject: Handle bad JSON data being returned from the federation API. (#9070) --- changelog.d/9070.bugfix | 1 + synapse/http/matrixfederationclient.py | 10 ++++++++++ tests/http/test_fedclient.py | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9070.bugfix (limited to 'tests') diff --git a/changelog.d/9070.bugfix b/changelog.d/9070.bugfix new file mode 100644 index 0000000000..72b8fe9f1c --- /dev/null +++ b/changelog.d/9070.bugfix @@ -0,0 +1 @@ +Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index b261e078c4..b7103d6541 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -174,6 +174,16 @@ async def _handle_json_response( d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = await make_deferred_yieldable(d) + except ValueError as e: + # The JSON content was invalid. 
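# (Editorial aside, not part of the patch: the ValueError caught above is
# typically json.JSONDecodeError, which subclasses ValueError. For example:
#
#     import json
#     try:
#         json.loads("not json")
#     except ValueError as exc:
#         print(type(exc).__name__)  # JSONDecodeError
#
# A malformed body from the remote server therefore takes this branch and
# is re-raised as a non-retryable RequestSendFailed below.)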
+ logger.warning( + "{%s} [%s] Failed to parse JSON response - %s %s", + request.txn_id, + request.destination, + request.method, + request.uri.decode("ascii"), + ) + raise RequestSendFailed(e, can_retry=False) from e except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 212484a7fe..9c52c8fdca 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -560,4 +560,4 @@ class FederationClientTests(HomeserverTestCase): self.pump() f = self.failureResultOf(test_d) - self.assertIsInstance(f.value, ValueError) + self.assertIsInstance(f.value, RequestSendFailed) -- cgit 1.5.1 From e385c8b4734b95c0738d6f4022d7bbb1621167db Mon Sep 17 00:00:00 2001 From: Marcus Date: Tue, 12 Jan 2021 18:20:30 +0100 Subject: Don't apply the IP range blacklist to proxy connections (#9084) It is expected that the proxy would be on a private IP address so the configured proxy should be connected to regardless of the IP range blacklist. --- changelog.d/9084.bugfix | 1 + synapse/http/client.py | 1 + synapse/http/proxyagent.py | 16 +++++- tests/http/test_proxyagent.py | 130 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 145 insertions(+), 3 deletions(-) create mode 100644 changelog.d/9084.bugfix (limited to 'tests') diff --git a/changelog.d/9084.bugfix b/changelog.d/9084.bugfix new file mode 100644 index 0000000000..415dd8b259 --- /dev/null +++ b/changelog.d/9084.bugfix @@ -0,0 +1 @@ +Don't blacklist connections to the configured proxy. Contributed by @Bubu. diff --git a/synapse/http/client.py b/synapse/http/client.py index 29f40ddf5f..5f74ee1149 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -341,6 +341,7 @@ class SimpleHttpClient: self.agent = ProxyAgent( self.reactor, + hs.get_reactor(), connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index e32d3f43e0..b730d2c634 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -39,6 +39,10 @@ class ProxyAgent(_AgentBase): reactor: twisted reactor to place outgoing connections. + proxy_reactor: twisted reactor to use for connections to the proxy server + reactor might have some blacklisting applied (i.e. for DNS queries), + but we need unblocked access to the proxy. + contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the verification parameters of OpenSSL. 
The default is to use a `BrowserLikePolicyForHTTPS`, so unless you have special @@ -59,6 +63,7 @@ class ProxyAgent(_AgentBase): def __init__( self, reactor, + proxy_reactor=None, contextFactory=BrowserLikePolicyForHTTPS(), connectTimeout=None, bindAddress=None, @@ -68,6 +73,11 @@ class ProxyAgent(_AgentBase): ): _AgentBase.__init__(self, reactor, pool) + if proxy_reactor is None: + self.proxy_reactor = reactor + else: + self.proxy_reactor = proxy_reactor + self._endpoint_kwargs = {} if connectTimeout is not None: self._endpoint_kwargs["timeout"] = connectTimeout @@ -75,11 +85,11 @@ class ProxyAgent(_AgentBase): self._endpoint_kwargs["bindAddress"] = bindAddress self.http_proxy_endpoint = _http_proxy_endpoint( - http_proxy, reactor, **self._endpoint_kwargs + http_proxy, self.proxy_reactor, **self._endpoint_kwargs ) self.https_proxy_endpoint = _http_proxy_endpoint( - https_proxy, reactor, **self._endpoint_kwargs + https_proxy, self.proxy_reactor, **self._endpoint_kwargs ) self._policy_for_https = contextFactory @@ -137,7 +147,7 @@ class ProxyAgent(_AgentBase): request_path = uri elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: endpoint = HTTPConnectProxyEndpoint( - self._reactor, + self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index 22abf76515..9a56e1c14a 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -15,12 +15,14 @@ import logging import treq +from netaddr import IPSet from twisted.internet import interfaces # noqa: F401 from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.web.http import HTTPChannel +from synapse.http.client import BlacklistingReactorWrapper from synapse.http.proxyagent import ProxyAgent from tests.http import TestServerTLSConnectionFactory, get_test_https_policy @@ -292,6 +294,134 @@ class MatrixFederationAgentTests(TestCase): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") + def test_http_request_via_proxy_with_blacklist(self): + # The blacklist includes the configured proxy IP. 
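# (Editorial note, not part of the patch: this is the same wiring the
# SimpleHttpClient change above sets up in production. The blacklisting
# wrapper is the agent's general-purpose reactor, while the raw reactor is
# passed separately so that only the connection to the proxy bypasses the
# blacklist. Below, proxy.com resolves to 1.2.3.5, squarely inside the
# blacklisted 1.0.0.0/8 range, and the request must still succeed.)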
+ agent = ProxyAgent( + BlacklistingReactorWrapper( + self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"]) + ), + self.reactor, + http_proxy=b"proxy.com:8888", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 8888) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"http://test.com") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request_via_proxy_with_blacklist(self): + # The blacklist includes the configured proxy IP. + agent = ProxyAgent( + BlacklistingReactorWrapper( + self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"]) + ), + self.reactor, + contextFactory=get_test_https_policy(), + https_proxy=b"proxy.com", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 1080) + + # make a test HTTP server, and wire up the client + proxy_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # fish the transports back out so that we can do the old switcheroo + s2c_transport = proxy_server.transport + client_protocol = s2c_transport.other + c2s_transport = client_protocol.transport + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"test.com:443") + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + # this just stops the http Request trying to do a chunked response + # request.setHeader(b"Content-Length", b"0") + request.finish() + + # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel + ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) + ssl_protocol = ssl_factory.buildProtocol(None) + http_server = ssl_protocol.wrappedProtocol + + ssl_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, ssl_protocol) + ) + c2s_transport.other = ssl_protocol + + self.reactor.advance(0) + + server_name = ssl_protocol._tlsConnection.get_servername() + expected_sni = b"test.com" + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 
1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + def _wrap_server_factory_for_tls(factory, sanlist=None): """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory -- cgit 1.5.1 From 7a2e9b549defe3f55531711a863183a33e7af83c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 12 Jan 2021 22:30:15 +0100 Subject: Remove user's avatar URL and displayname when deactivated. (#8932) This only applies if the user's data is to be erased. --- changelog.d/8932.feature | 1 + docs/admin_api/user_admin_api.rst | 21 +++ synapse/handlers/deactivate_account.py | 18 ++- synapse/handlers/profile.py | 8 +- synapse/rest/admin/users.py | 22 ++- synapse/rest/client/v2_alpha/account.py | 7 +- synapse/server.py | 2 +- synapse/storage/databases/main/profile.py | 2 +- tests/handlers/test_profile.py | 30 ++++ tests/rest/admin/test_user.py | 220 ++++++++++++++++++++++++++++++ tests/rest/client/v1/test_login.py | 5 +- tests/rest/client/v1/test_rooms.py | 6 +- tests/storage/test_profile.py | 26 ++++ 13 files changed, 351 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8932.feature (limited to 'tests') diff --git a/changelog.d/8932.feature b/changelog.d/8932.feature new file mode 100644 index 0000000000..a1d17394d7 --- /dev/null +++ b/changelog.d/8932.feature @@ -0,0 +1 @@ +Remove a user's avatar URL and display name when deactivated with the Admin API. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 3115951e1f..b3d413cf57 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -98,6 +98,8 @@ Body parameters: - ``deactivated``, optional. If unspecified, deactivation state will be left unchanged on existing accounts and set to ``false`` for new accounts. + A user cannot be erased by deactivating with this API. For details on deactivating users see + `Deactivate Account <#deactivate-account>`_. If the user already exists then optional parameters default to the current value. @@ -248,6 +250,25 @@ server admin: see `README.rst `_. The erase parameter is optional and defaults to ``false``. An empty body may be passed for backwards compatibility. 
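(Editorial example: what a call to the deactivate endpoint described here might look like from a script. The homeserver URL, user id and token are placeholders, and `requests` is just a convenient stand-in for any HTTP client.)

    import requests

    resp = requests.post(
        "https://homeserver.example/_synapse/admin/v1/deactivate/@user:example.com",
        headers={"Authorization": "Bearer <admin_access_token>"},
        json={"erase": True},
    )
    # The response reports how unbinding from the identity server went,
    # e.g. {"id_server_unbind_result": "success"}
    print(resp.json())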
+The following actions are performed when deactivating a user:
+
+- Try to unbind 3PIDs from the identity server
+- Remove all 3PIDs from the homeserver
+- Delete all devices and E2EE keys
+- Delete all access tokens
+- Delete the password hash
+- Remove the user from all rooms they are a member of
+- Remove the user from the user directory
+- Reject all pending invites
+- Remove all account validity information related to the user
+
+The following additional actions are performed during deactivation if ``erase``
+is set to ``true``:
+
+- Remove the user's display name
+- Remove the user's avatar URL
+- Mark the user as erased
+
 Reset password
 ==============
 
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e808142365..c4a3b26a84 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
 
 from ._base import BaseHandler
 
@@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
         self._device_handler = hs.get_device_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._identity_handler = hs.get_identity_handler()
+        self._profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
 
         self._server_name = hs.hostname
@@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
         self._account_validity_enabled = hs.config.account_validity.enabled
 
     async def deactivate_account(
-        self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+        self,
+        user_id: str,
+        erase_data: bool,
+        requester: Requester,
+        id_server: Optional[str] = None,
+        by_admin: bool = False,
     ) -> bool:
         """Deactivate a user's account
 
         Args:
             user_id: ID of user to be deactivated
             erase_data: whether to GDPR-erase the user's data
+            requester: The user attempting to make this change.
             id_server: Use the given identity server when unbinding
                 any threepids. If None then will attempt to unbind using the
                 identity server specified when binding (if known).
+            by_admin: Whether this change was made by an administrator.
 
         Returns:
             True if identity server supports removing threepids, otherwise False.
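(Editorial sketch: how a caller might drive the new `deactivate_account` signature. The `handler` and user ids are assumed; `create_requester` is the helper imported above.)

    from synapse.types import create_requester

    async def erase_user(handler, target_user_id: str, admin_user_id: str) -> bool:
        # erase_data=True takes the GDPR-erasure path, which after this
        # commit also blanks the display name and avatar URL; by_admin=True
        # marks the change as administrative.
        return await handler.deactivate_account(
            target_user_id,
            True,
            create_requester(admin_user_id),
            by_admin=True,
        )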
@@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
 
         # Mark the user as erased, if they asked for that
         if erase_data:
+            user = UserID.from_string(user_id)
+            # Remove avatar URL from this user
+            await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+            # Remove displayname from this user
+            await self._profile_handler.set_displayname(user, requester, "", by_admin)
+
             logger.info("Marking %s as erased", user_id)
             await self.store.mark_user_erased(user_id)
 
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 36f9ee4b71..c02b951031 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
             )
 
+        avatar_url_to_set = new_avatar_url  # type: Optional[str]
+        if new_avatar_url == "":
+            avatar_url_to_set = None
+
         # Same as set_displayname
         if by_admin:
             requester = create_requester(
                 target_user, authenticated_entity=requester.authenticated_entity
             )
 
-        await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+        await self.store.set_profile_avatar_url(
+            target_user.localpart, avatar_url_to_set
+        )
 
         if self.hs.config.user_directory_search_all_users:
             profile = await self.store.get_profileinfo(target_user.localpart)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f8a73e7d9d..f39e3d6d5c 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -244,7 +244,7 @@ class UserRestServletV2(RestServlet):
 
             if deactivate and not user["deactivated"]:
                 await self.deactivate_account_handler.deactivate_account(
-                    target_user.to_string(), False
+                    target_user.to_string(), False, requester, by_admin=True
                 )
             elif not deactivate and user["deactivated"]:
                 if "password" not in body:
@@ -486,12 +486,22 @@ class WhoisRestServlet(RestServlet):
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
+        self.store = hs.get_datastore()
+
+    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        if not self.is_mine(UserID.from_string(target_user_id)):
+            raise SynapseError(400, "Can only deactivate local users")
+
+        if not await self.store.get_user_by_id(target_user_id):
+            raise NotFoundError("User not found")
 
-    async def on_POST(self, request, target_user_id):
-        await assert_requester_is_admin(self.auth, request)
         body = parse_json_object_from_request(request, allow_empty_body=True)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -501,10 +511,8 @@ class DeactivateAccountRestServlet(RestServlet):
                 Codes.BAD_JSON,
             )
 
-        UserID.from_string(target_user_id)
-
         result = await self._deactivate_account_handler.deactivate_account(
-            target_user_id, erase
+            target_user_id, erase, requester, by_admin=True
         )
         if result:
             id_server_unbind_result = "success"
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 3b50dc885f..65e68d641b 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -305,7 +305,7 @@ class DeactivateAccountRestServlet(RestServlet):
 
         # allow ASes to deactivate their own users
         if requester.app_service:
             await
self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase + requester.user.to_string(), erase, requester ) return 200, {} @@ -313,7 +313,10 @@ class DeactivateAccountRestServlet(RestServlet): requester, request, body, "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, id_server=body.get("id_server") + requester.user.to_string(), + erase, + requester, + id_server=body.get("id_server"), ) if result: id_server_unbind_result = "success" diff --git a/synapse/server.py b/synapse/server.py index 12da92b63c..d4c235cda5 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -501,7 +501,7 @@ class HomeServer(metaclass=abc.ABCMeta): return InitialSyncHandler(self) @cache_in_self - def get_profile_handler(self): + def get_profile_handler(self) -> ProfileHandler: return ProfileHandler(self) @cache_in_self diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 0e25ca3d7a..54ef0f1f54 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_avatar_url( - self, user_localpart: str, new_avatar_url: str + self, user_localpart: str, new_avatar_url: Optional[str] ) -> None: await self.db_pool.simple_update_one( table="profiles", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 919547556b..022943a10a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -105,6 +105,21 @@ class ProfileTestCase(unittest.TestCase): "Frank", ) + # Set displayname to an empty string + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "" + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_set_my_name_if_disabled(self): self.hs.config.enable_set_displayname = False @@ -223,6 +238,21 @@ class ProfileTestCase(unittest.TestCase): "http://my.server/me.png", ) + # Set avatar to an empty string + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, synapse.types.create_requester(self.frank), "", + ) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), + ) + @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): self.hs.config.enable_set_avatar_url = False diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ad4588c1da..04599c2fcf 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -588,6 +588,200 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "bar", "user_id") +class DeactivateAccountTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass", displayname="User1") + self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) + self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote( + self.other_user + ) 
+
+        # set attributes for user
+        self.get_success(
+            self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+        )
+        self.get_success(
+            self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+        )
+
+    def test_no_auth(self):
+        """
+        Try to deactivate users without authentication.
+        """
+        channel = self.make_request("POST", self.url, b"{}")
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        url = "/_synapse/admin/v1/deactivate/@bob:test"
+
+        channel = self.make_request("POST", url, access_token=self.other_user_token)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+        channel = self.make_request(
+            "POST", url, access_token=self.other_user_token, content=b"{}",
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+    def test_user_does_not_exist(self):
+        """
+        Tests that deactivation for a user that does not exist returns a 404
+        """
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/deactivate/@unknown_person:test",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_erase_is_not_bool(self):
+        """
+        If parameter `erase` is not boolean, return an error
+        """
+        body = json.dumps({"erase": "False"})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that deactivation for a user that is not local returns a 400
+        """
+        url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual("Can only deactivate local users", channel.json_body["error"])
+
+    def test_deactivate_user_erase_true(self):
+        """
+        Test deactivating a user and set `erase` to `true`
+        """
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        # Deactivate user
+        body = json.dumps({"erase": True})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
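# (Editorial note: the assertions that follow verify the erasure
# side-effects: threepids cleared, avatar_url and displayname nulled, and
# the user reported as erased by _is_erased.)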
+ self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + self._is_erased("@user:test", True) + + def test_deactivate_user_erase_false(self): + """ + Test deactivating an user and set `erase` to `false` + """ + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + # Deactivate user + body = json.dumps({"erase": False}) + + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User1", channel.json_body["displayname"]) + + self._is_erased("@user:test", False) + + def _is_erased(self, user_id: str, expect: bool) -> None: + """Assert that the user is erased or not + """ + d = self.store.is_user_erased(user_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertFalse(self.get_success(d)) + + class UserRestTestCase(unittest.HomeserverTestCase): servlets = [ @@ -987,6 +1181,26 @@ class UserRestTestCase(unittest.HomeserverTestCase): Test deactivating another user. 
""" + # set attributes for user + self.get_success( + self.store.set_profile_avatar_url("user", "mxc://servername/mediaid") + ) + self.get_success( + self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) + ) + + # Get user + channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) + # Deactivate user body = json.dumps({"deactivated": True}) @@ -1000,6 +1214,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) # the user is deactivated, the threepid will be deleted # Get user @@ -1010,6 +1227,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) + self.assertEqual("User", channel.json_body["displayname"]) @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) def test_change_name_deactivate_user_user_directory(self): diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 1d1dc9f8a2..f9b8011961 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -30,6 +30,7 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.types import create_requester from tests import unittest from tests.handlers.test_oidc import HAS_OIDC @@ -667,7 +668,9 @@ class CASTestCase(unittest.HomeserverTestCase): # Deactivate the account. self.get_success( - self.deactivate_account_handler.deactivate_account(self.user_id, False) + self.deactivate_account_handler.deactivate_account( + self.user_id, False, create_requester(self.user_id) + ) ) # Request the CAS ticket. 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6105eac47c..d4e3165436 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,7 +29,7 @@ from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account -from synapse.types import JsonDict, RoomAlias, UserID +from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -1687,7 +1687,9 @@ class ContextTestCase(unittest.HomeserverTestCase): deactivate_account_handler = self.hs.get_deactivate_account_handler() self.get_success( - deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) + deactivate_account_handler.deactivate_account( + self.user_id, True, create_requester(self.user_id) + ) ) # Invite another user in the room. This is needed because messages will be diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 3fd0a38cf5..ea63bd56b4 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase): ), ) + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ) + ) + @defer.inlineCallbacks def test_avatar_url(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) @@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase): ) ), ) + + # test set to None + yield defer.ensureDeferred( + self.store.set_profile_avatar_url(self.u_frank.localpart, None) + ) + + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ) + ) -- cgit 1.5.1 From bc4bf7b384d88189e2f9c5d1d4e00960a42792f5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 13 Jan 2021 10:26:12 +0000 Subject: Preparatory refactors of OidcHandler (#9067) Some light refactoring of OidcHandler, in preparation for bigger things: * remove inheritance from deprecated BaseHandler * add an object to hold the things that go into a session cookie * factor out a separate class for manipulating said cookies --- changelog.d/9067.feature | 1 + synapse/handlers/oidc_handler.py | 304 +++++++++++++++++++++------------------ tests/handlers/test_oidc.py | 61 ++++---- 3 files changed, 201 insertions(+), 165 deletions(-) create mode 100644 changelog.d/9067.feature (limited to 'tests') diff --git a/changelog.d/9067.feature b/changelog.d/9067.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9067.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 6835c6c462..88097639ef 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -14,7 +14,7 @@ # limitations under the License. 
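# (Editorial sketch, not part of the patch: taken together, the changes in
# this file let callers round-trip a session cookie through the two new
# classes. Assuming `hs` is a HomeServer with a macaroon secret configured:
#
#     from synapse.handlers.oidc_handler import (
#         OidcSessionData,
#         OidcSessionTokenGenerator,
#     )
#
#     gen = OidcSessionTokenGenerator(hs)
#     token = gen.generate_oidc_session_token(
#         state="state",
#         session_data=OidcSessionData(
#             nonce="nonce", client_redirect_url="https://client/redirect",
#         ),
#     )
#     data = gen.verify_oidc_session_token(token.encode("utf-8"), "state")
#     assert data.nonce == "nonce"
# )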
import inspect import logging -from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar from urllib.parse import urlencode import attr @@ -35,7 +35,6 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from synapse.config import ConfigError -from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable @@ -85,12 +84,15 @@ class OidcError(Exception): return self.error -class OidcHandler(BaseHandler): +class OidcHandler: """Handles requests related to the OpenID Connect login flow. """ def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self._store = hs.get_datastore() + + self._token_generator = OidcSessionTokenGenerator(hs) + self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._user_profile_method = hs.config.oidc_user_profile_method # type: str @@ -116,7 +118,6 @@ class OidcHandler(BaseHandler): self._http_client = hs.get_proxied_http_client() self._server_name = hs.config.server_name # type: str - self._macaroon_secret_key = hs.config.macaroon_secret_key # identifier for the external_ids table self.idp_id = "oidc" @@ -519,11 +520,13 @@ class OidcHandler(BaseHandler): if not client_redirect_url: client_redirect_url = b"" - cookie = self._generate_oidc_session_token( + cookie = self._token_generator.generate_oidc_session_token( state=state, - nonce=nonce, - client_redirect_url=client_redirect_url.decode(), - ui_auth_session_id=ui_auth_session_id, + session_data=OidcSessionData( + nonce=nonce, + client_redirect_url=client_redirect_url.decode(), + ui_auth_session_id=ui_auth_session_id, + ), ) request.addCookie( SESSION_COOKIE_NAME, @@ -628,11 +631,9 @@ class OidcHandler(BaseHandler): # Deserialize the session token and verify it. 
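# (Editorial note: the tuple that the removed code unpacks here is replaced
# by a single OidcSessionData value; nonce, client_redirect_url and
# ui_auth_session_id are read off it as attributes further down.)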
try: - ( - nonce, - client_redirect_url, - ui_auth_session_id, - ) = self._verify_oidc_session_token(session, state) + session_data = self._token_generator.verify_oidc_session_token( + session, state + ) except MacaroonDeserializationException as e: logger.exception("Invalid session") self._sso_handler.render_error(request, "invalid_session", str(e)) @@ -674,14 +675,14 @@ class OidcHandler(BaseHandler): else: logger.debug("Extracting userinfo from id_token") try: - userinfo = await self._parse_id_token(token, nonce=nonce) + userinfo = await self._parse_id_token(token, nonce=session_data.nonce) except Exception as e: logger.exception("Invalid id_token") self._sso_handler.render_error(request, "invalid_token", str(e)) return # first check if we're doing a UIA - if ui_auth_session_id: + if session_data.ui_auth_session_id: try: remote_user_id = self._remote_id_from_userinfo(userinfo) except Exception as e: @@ -690,7 +691,7 @@ class OidcHandler(BaseHandler): return return await self._sso_handler.complete_sso_ui_auth_request( - self.idp_id, remote_user_id, ui_auth_session_id, request + self.idp_id, remote_user_id, session_data.ui_auth_session_id, request ) # otherwise, it's a login @@ -698,133 +699,12 @@ class OidcHandler(BaseHandler): # Call the mapper to register/login the user try: await self._complete_oidc_login( - userinfo, token, request, client_redirect_url + userinfo, token, request, session_data.client_redirect_url ) except MappingException as e: logger.exception("Could not map user") self._sso_handler.render_error(request, "mapping_error", str(e)) - def _generate_oidc_session_token( - self, - state: str, - nonce: str, - client_redirect_url: str, - ui_auth_session_id: Optional[str], - duration_in_ms: int = (60 * 60 * 1000), - ) -> str: - """Generates a signed token storing data about an OIDC session. - - When Synapse initiates an authorization flow, it creates a random state - and a random nonce. Those parameters are given to the provider and - should be verified when the client comes back from the provider. - It is also used to store the client_redirect_url, which is used to - complete the SSO login flow. - - Args: - state: The ``state`` parameter passed to the OIDC provider. - nonce: The ``nonce`` parameter passed to the OIDC provider. - client_redirect_url: The URL the client gave when it initiated the - flow. - ui_auth_session_id: The session ID of the ongoing UI Auth (or - None if this is a login). - duration_in_ms: An optional duration for the token in milliseconds. - Defaults to an hour. - - Returns: - A signed macaroon token with the session information. - """ - macaroon = pymacaroons.Macaroon( - location=self._server_name, identifier="key", key=self._macaroon_secret_key, - ) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("type = session") - macaroon.add_first_party_caveat("state = %s" % (state,)) - macaroon.add_first_party_caveat("nonce = %s" % (nonce,)) - macaroon.add_first_party_caveat( - "client_redirect_url = %s" % (client_redirect_url,) - ) - if ui_auth_session_id: - macaroon.add_first_party_caveat( - "ui_auth_session_id = %s" % (ui_auth_session_id,) - ) - now = self.clock.time_msec() - expiry = now + duration_in_ms - macaroon.add_first_party_caveat("time < %d" % (expiry,)) - - return macaroon.serialize() - - def _verify_oidc_session_token( - self, session: bytes, state: str - ) -> Tuple[str, str, Optional[str]]: - """Verifies and extract an OIDC session token. 
- - This verifies that a given session token was issued by this homeserver - and extract the nonce and client_redirect_url caveats. - - Args: - session: The session token to verify - state: The state the OIDC provider gave back - - Returns: - The nonce, client_redirect_url, and ui_auth_session_id for this session - """ - macaroon = pymacaroons.Macaroon.deserialize(session) - - v = pymacaroons.Verifier() - v.satisfy_exact("gen = 1") - v.satisfy_exact("type = session") - v.satisfy_exact("state = %s" % (state,)) - v.satisfy_general(lambda c: c.startswith("nonce = ")) - v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) - # Sometimes there's a UI auth session ID, it seems to be OK to attempt - # to always satisfy this. - v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) - v.satisfy_general(self._verify_expiry) - - v.verify(macaroon, self._macaroon_secret_key) - - # Extract the `nonce`, `client_redirect_url`, and maybe the - # `ui_auth_session_id` from the token. - nonce = self._get_value_from_macaroon(macaroon, "nonce") - client_redirect_url = self._get_value_from_macaroon( - macaroon, "client_redirect_url" - ) - try: - ui_auth_session_id = self._get_value_from_macaroon( - macaroon, "ui_auth_session_id" - ) # type: Optional[str] - except ValueError: - ui_auth_session_id = None - - return nonce, client_redirect_url, ui_auth_session_id - - def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: - """Extracts a caveat value from a macaroon token. - - Args: - macaroon: the token - key: the key of the caveat to extract - - Returns: - The extracted value - - Raises: - Exception: if the caveat was not in the macaroon - """ - prefix = key + " = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(prefix): - return caveat.caveat_id[len(prefix) :] - raise ValueError("No %s caveat in macaroon" % (key,)) - - def _verify_expiry(self, caveat: str) -> bool: - prefix = "time < " - if not caveat.startswith(prefix): - return False - expiry = int(caveat[len(prefix) :]) - now = self.clock.time_msec() - return now < expiry - async def _complete_oidc_login( self, userinfo: UserInfo, @@ -901,8 +781,8 @@ class OidcHandler(BaseHandler): # and attempt to match it. attributes = await oidc_response_to_user_attributes(failures=0) - user_id = UserID(attributes.localpart, self.server_name).to_string() - users = await self.store.get_users_by_id_case_insensitive(user_id) + user_id = UserID(attributes.localpart, self._server_name).to_string() + users = await self._store.get_users_by_id_case_insensitive(user_id) if users: # If an existing matrix ID is returned, then use it. if len(users) == 1: @@ -954,6 +834,148 @@ class OidcHandler(BaseHandler): return str(remote_user_id) +class OidcSessionTokenGenerator: + """Methods for generating and checking OIDC Session cookies.""" + + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() + self._server_name = hs.hostname + self._macaroon_secret_key = hs.config.key.macaroon_secret_key + + def generate_oidc_session_token( + self, + state: str, + session_data: "OidcSessionData", + duration_in_ms: int = (60 * 60 * 1000), + ) -> str: + """Generates a signed token storing data about an OIDC session. + + When Synapse initiates an authorization flow, it creates a random state + and a random nonce. Those parameters are given to the provider and + should be verified when the client comes back from the provider. 
+        It is also used to store the client_redirect_url, which is used to
+        complete the SSO login flow.
+
+        Args:
+            state: The ``state`` parameter passed to the OIDC provider.
+            session_data: data to include in the session token.
+            duration_in_ms: An optional duration for the token in milliseconds.
+                Defaults to an hour.
+
+        Returns:
+            A signed macaroon token with the session information.
+        """
+        macaroon = pymacaroons.Macaroon(
+            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+        )
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = session")
+        macaroon.add_first_party_caveat("state = %s" % (state,))
+        macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
+        macaroon.add_first_party_caveat(
+            "client_redirect_url = %s" % (session_data.client_redirect_url,)
+        )
+        if session_data.ui_auth_session_id:
+            macaroon.add_first_party_caveat(
+                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+            )
+        now = self._clock.time_msec()
+        expiry = now + duration_in_ms
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+        return macaroon.serialize()
+
+    def verify_oidc_session_token(
+        self, session: bytes, state: str
+    ) -> "OidcSessionData":
+        """Verifies and extracts an OIDC session token.
+
+        This verifies that a given session token was issued by this homeserver
+        and extracts the nonce and client_redirect_url caveats.
+
+        Args:
+            session: The session token to verify
+            state: The state the OIDC provider gave back
+
+        Returns:
+            The data extracted from the session cookie
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(session)
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = session")
+        v.satisfy_exact("state = %s" % (state,))
+        v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # Sometimes there's a UI auth session ID; it seems to be OK to attempt
+        # to always satisfy this.
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+        v.satisfy_general(self._verify_expiry)
+
+        v.verify(macaroon, self._macaroon_secret_key)
+
+        # Extract the `nonce`, `client_redirect_url`, and maybe the
+        # `ui_auth_session_id` from the token.
+        nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        client_redirect_url = self._get_value_from_macaroon(
+            macaroon, "client_redirect_url"
+        )
+        try:
+            ui_auth_session_id = self._get_value_from_macaroon(
+                macaroon, "ui_auth_session_id"
+            )  # type: Optional[str]
+        except ValueError:
+            ui_auth_session_id = None
+
+        return OidcSessionData(
+            nonce=nonce,
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=ui_auth_session_id,
+        )
+
+    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+        """Extracts a caveat value from a macaroon token.
+ + Args: + macaroon: the token + key: the key of the caveat to extract + + Returns: + The extracted value + + Raises: + Exception: if the caveat was not in the macaroon + """ + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise ValueError("No %s caveat in macaroon" % (key,)) + + def _verify_expiry(self, caveat: str) -> bool: + prefix = "time < " + if not caveat.startswith(prefix): + return False + expiry = int(caveat[len(prefix) :]) + now = self._clock.time_msec() + return now < expiry + + +@attr.s(frozen=True, slots=True) +class OidcSessionData: + """The attributes which are stored in a OIDC session cookie""" + + # The `nonce` parameter passed to the OIDC provider. + nonce = attr.ib(type=str) + + # The URL the client gave when it initiated the flow. ("" if this is a UI Auth) + client_redirect_url = attr.ib(type=str) + + # The session ID of the ongoing UI Auth (None if this is a login) + ui_auth_session_id = attr.ib(type=Optional[str], default=None) + + UserAttributeDict = TypedDict( "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 4ce0f74f22..2abd7a83b5 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -14,7 +14,7 @@ # limitations under the License. import json import re -from typing import Dict +from typing import Dict, Optional from urllib.parse import parse_qs, urlencode, urlparse from mock import ANY, Mock, patch @@ -349,9 +349,13 @@ class OidcHandlerTestCase(HomeserverTestCase): cookie = args[1] macaroon = pymacaroons.Macaroon.deserialize(cookie) - state = self.handler._get_value_from_macaroon(macaroon, "state") - nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") - redirect = self.handler._get_value_from_macaroon( + state = self.handler._token_generator._get_value_from_macaroon( + macaroon, "state" + ) + nonce = self.handler._token_generator._get_value_from_macaroon( + macaroon, "nonce" + ) + redirect = self.handler._token_generator._get_value_from_macaroon( macaroon, "client_redirect_url" ) @@ -411,12 +415,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url = "http://client/redirect" user_agent = "Browser" ip_address = "10.0.0.1" - session = self.handler._generate_oidc_session_token( - state=state, - nonce=nonce, - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, - ) + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) request = _build_callback_request( code, state, session, user_agent=user_agent, ip_address=ip_address ) @@ -500,11 +499,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("invalid_session") # Mismatching session - session = self.handler._generate_oidc_session_token( - state="state", - nonce="nonce", - client_redirect_url="http://client/redirect", - ui_auth_session_id=None, + session = self._generate_oidc_session_token( + state="state", nonce="nonce", client_redirect_url="http://client/redirect", ) request.args = {} request.args[b"state"] = [b"mismatching state"] @@ -623,11 +619,8 @@ class OidcHandlerTestCase(HomeserverTestCase): state = "state" client_redirect_url = "http://client/redirect" - session = self.handler._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, + session = self._generate_oidc_session_token( + state=state, nonce="nonce", 
client_redirect_url=client_redirect_url, ) request = _build_callback_request("code", state, session) @@ -841,6 +834,24 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.assertRenderedError("mapping_error", "localpart is invalid: ") + def _generate_oidc_session_token( + self, + state: str, + nonce: str, + client_redirect_url: str, + ui_auth_session_id: Optional[str] = None, + ) -> str: + from synapse.handlers.oidc_handler import OidcSessionData + + return self.handler._token_generator.generate_oidc_session_token( + state=state, + session_data=OidcSessionData( + nonce=nonce, + client_redirect_url=client_redirect_url, + ui_auth_session_id=ui_auth_session_id, + ), + ) + class UsernamePickerTestCase(HomeserverTestCase): if not HAS_OIDC: @@ -965,17 +976,19 @@ async def _make_callback_with_userinfo( userinfo: the OIDC userinfo dict client_redirect_url: the URL to redirect to on success. """ + from synapse.handlers.oidc_handler import OidcSessionData + handler = hs.get_oidc_handler() handler._exchange_code = simple_async_mock(return_value={}) handler._parse_id_token = simple_async_mock(return_value=userinfo) handler._fetch_userinfo = simple_async_mock(return_value=userinfo) state = "state" - session = handler._generate_oidc_session_token( + session = handler._token_generator.generate_oidc_session_token( state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, + session_data=OidcSessionData( + nonce="nonce", client_redirect_url=client_redirect_url, + ), ) request = _build_callback_request("code", state, session) -- cgit 1.5.1 From 98a64b7f7f256b7afd4a1d735cb32d099e44831a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Jan 2021 07:05:16 -0500 Subject: Add basic domain validation for `DomainSpecificString.is_valid`. (#9071) This checks that the domain given to `DomainSpecificString.is_valid` (e.g. `UserID`, `RoomAlias`, etc.) is of a valid form. Previously some validation was done on the localpart (e.g. the sigil), but not the domain portion. --- changelog.d/9071.bugfix | 1 + synapse/types.py | 8 +++++++- tests/test_types.py | 4 ++++ 3 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9071.bugfix (limited to 'tests') diff --git a/changelog.d/9071.bugfix b/changelog.d/9071.bugfix new file mode 100644 index 0000000000..0201271f84 --- /dev/null +++ b/changelog.d/9071.bugfix @@ -0,0 +1 @@ +Fix "Failed to send request" errors when a client provides an invalid room alias. diff --git a/synapse/types.py b/synapse/types.py index c7d4e95809..20a43d05bf 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -37,6 +37,7 @@ from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError +from synapse.http.endpoint import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService @@ -257,8 +258,13 @@ class DomainSpecificString( @classmethod def is_valid(cls: Type[DS], s: str) -> bool: + """Parses the input string and attempts to ensure it is valid.""" try: - cls.from_string(s) + obj = cls.from_string(s) + # Apply additional validation to the domain. This is only done + # during is_valid (and not part of from_string) since it is + # possible for invalid data to exist in room-state, etc. 
+ parse_and_validate_server_name(obj.domain) return True except Exception: return False diff --git a/tests/test_types.py b/tests/test_types.py index 480bea1bdc..acdeea7a09 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -58,6 +58,10 @@ class RoomAliasTestCase(unittest.HomeserverTestCase): self.assertEquals(room.to_string(), "#channel:my.domain") + def test_validate(self): + id_string = "#test:domain,test" + self.assertFalse(RoomAlias.is_valid(id_string)) + class GroupIDTestCase(unittest.TestCase): def test_parse(self): -- cgit 1.5.1 From 233c8b9fce99616631614d176de4612c29277884 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 13 Jan 2021 20:21:55 +0000 Subject: Add a test for UI-Auth-via-SSO (#9082) * Add complete test for UI-Auth-via-SSO. * review comments --- changelog.d/9082.feature | 1 + tests/rest/client/v1/utils.py | 196 ++++++++++++++++++++++++++------ tests/rest/client/v2_alpha/test_auth.py | 40 ++++++- tests/server.py | 32 +++++- 4 files changed, 227 insertions(+), 42 deletions(-) create mode 100644 changelog.d/9082.feature (limited to 'tests') diff --git a/changelog.d/9082.feature b/changelog.d/9082.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9082.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 81b7f84360..85d1709ead 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -2,7 +2,7 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +20,8 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Optional +from html.parser import HTMLParser +from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple from mock import patch @@ -32,7 +33,7 @@ from twisted.web.server import Site from synapse.api.constants import Membership from synapse.types import JsonDict -from tests.server import FakeSite, make_request +from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import FakeResponse @@ -362,34 +363,94 @@ class RestHelper: the normal places. """ client_redirect_url = "https://x" + channel = self.auth_via_oidc(remote_user_id, client_redirect_url) - # first hit the redirect url (which will issue a cookie and state) + # expect a confirmation page + assert channel.code == 200 + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.text_body, + ) + assert m, channel.text_body + login_token = m.group(1) + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token and device id. 
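# (Editorial note: after this refactor, login_via_oidc is just the final
# token exchange below; the redirect/callback legwork has moved into the
# reusable auth_via_oidc helper that follows it.)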
channel = make_request( self.hs.get_reactor(), self.site, - "GET", - "/login/sso/redirect?redirectUrl=" + client_redirect_url, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, ) - # that will redirect to the OIDC IdP, but we skip that and go straight + assert channel.code == 200 + return channel.json_body + + def auth_via_oidc( + self, + remote_user_id: str, + client_redirect_url: Optional[str] = None, + ui_auth_session_id: Optional[str] = None, + ) -> FakeChannel: + """Perform an OIDC authentication flow via a mock OIDC provider. + + This can be used for either login or user-interactive auth. + + Starts by making a request to the relevant synapse redirect endpoint, which is + expected to serve a 302 to the OIDC provider. We then make a request to the + OIDC callback endpoint, intercepting the HTTP requests that will get sent back + to the OIDC provider. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + + Args: + remote_user_id: the remote id that the OIDC provider should present + client_redirect_url: for a login flow, the client redirect URL to pass to + the login redirect endpoint + ui_auth_session_id: if set, we will perform a UI Auth flow. The session id + of the UI auth. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + Note that the response code may be a 200, 302 or 400 depending on how things + went. + """ + + cookies = {} + + # if we're doing a ui auth, hit the ui auth redirect endpoint + if ui_auth_session_id: + # can't set the client redirect url for UI Auth + assert client_redirect_url is None + oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies) + else: + # otherwise, hit the login redirect endpoint + oauth_uri = self.initiate_sso_login(client_redirect_url, cookies) + + # we now have a URI for the OIDC IdP, but we skip that and go straight # back to synapse's OIDC callback resource. However, we do need the "state" - # param that synapse passes to the IdP via query params, and the cookie that - # synapse passes to the client. - assert channel.code == 302 - oauth_uri = channel.headers.getRawHeaders("Location")[0] - params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query) - redirect_uri = "%s?%s" % ( + # param that synapse passes to the IdP via query params, as well as the cookie + # that synapse passes to the client. + + oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1) + assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( + "unexpected SSO URI " + oauth_uri_path + ) + params = urllib.parse.parse_qs(oauth_uri_qs) + callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), ) - cookies = {} - for h in channel.headers.getRawHeaders("Set-Cookie"): - parts = h.split(";") - k, v = parts[0].split("=", maxsplit=1) - cookies[k] = v # before we hit the callback uri, stub out some methods in the http client so # that we don't have to handle full HTTPS requests. - # (expected url, json response) pairs, in the order we expect them. 
expected_requests = [ # first we get a hit to the token endpoint, which we tell to return @@ -413,34 +474,97 @@ class RestHelper: self.hs.get_reactor(), self.site, "GET", - redirect_uri, + callback_uri, custom_headers=[ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() ], ) + return channel - # expect a confirmation page - assert channel.code == 200 + def initiate_sso_login( + self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the login-via-sso redirect endpoint, and return the target - # fish the matrix login token out of the body of the confirmation page - m = re.search( - 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), - channel.result["body"].decode("utf-8"), - ) - assert m - login_token = m.group(1) + Assumes that exactly one SSO provider has been configured. Requires the login + servlet to be mounted. - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. + Args: + client_redirect_url: the client redirect URL to pass to the login redirect + endpoint + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets redirected to (ie, the SSO server) + """ + params = {} + if client_redirect_url: + params["redirectUrl"] = client_redirect_url + + # hit the redirect url (which will issue a cookie and state) channel = make_request( self.hs.get_reactor(), self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, + "GET", + "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), ) - assert channel.code == 200 - return channel.json_body + + assert channel.code == 302 + channel.extract_cookies(cookies) + return channel.headers.getRawHeaders("Location")[0] + + def initiate_sso_ui_auth( + self, ui_auth_session_id: str, cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the ui-auth-via-sso endpoint, and return the target + + Assumes that exactly one SSO provider has been configured. Requires the + AuthRestServlet to be mounted. + + Args: + ui_auth_session_id: the session id of the UI auth + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets linked to (ie, the SSO server) + """ + sso_redirect_endpoint = ( + "/_matrix/client/r0/auth/m.login.sso/fallback/web?" + + urllib.parse.urlencode({"session": ui_auth_session_id}) + ) + # hit the redirect url (which will issue a cookie and state) + channel = make_request( + self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint + ) + # that should serve a confirmation page + assert channel.code == 200, channel.text_body + channel.extract_cookies(cookies) + + # parse the confirmation page to fish out the link. + class ConfirmationPageParser(HTMLParser): + def __init__(self): + super().__init__() + + self.links = [] # type: List[str] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "a": + href = attr_dict["href"] + if href: + self.links.append(href) + + def error(_, message): + raise AssertionError(message) + + p = ConfirmationPageParser() + p.feed(channel.text_body) + p.close() + assert len(p.links) == 1, "not exactly one link in confirmation page" + oauth_uri = p.links[0] + return oauth_uri # an 'oidc_config' suitable for login_via_oidc. 
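
As an aside, the link-scraping approach used by ConfirmationPageParser above is easy to
exercise standalone. A minimal sketch using only the stdlib html.parser (the HTML
snippet and hostname below are made up for illustration):

    from html.parser import HTMLParser
    from typing import Iterable, List, Optional, Tuple

    class LinkParser(HTMLParser):
        """Collect the href of every <a> tag encountered."""

        def __init__(self):
            super().__init__()
            self.links = []  # type: List[str]

        def handle_starttag(
            self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
        ) -> None:
            if tag == "a":
                href = dict(attrs).get("href")
                if href:
                    self.links.append(href)

    p = LinkParser()
    p.feed('<a href="https://idp.example.test/auth?state=abc">Continue</a>')
    p.close()
    assert p.links == ["https://idp.example.test/auth?state=abc"]
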
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index bb91e0c331..5f6ca23b06 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector +# Copyright 2020-2021 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Union from twisted.internet.defer import succeed @@ -384,6 +384,44 @@ class UIAuthTests(unittest.HomeserverTestCase): # Note that *no auth* information is provided, not even a session iD! self.delete_device(self.user_tok, self.device_id, 200) + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_via_sso(self): + """Test a successful UI Auth flow via SSO + + This includes: + * hitting the UIA SSO redirect endpoint + * checking it serves a confirmation page which links to the OIDC provider + * calling back to the synapse oidc callback + * checking that the original operation succeeds + """ + + # log the user in + remote_user_id = UserID.from_string(self.user).localpart + login_resp = self.helper.login_via_oidc(remote_user_id) + self.assertEqual(login_resp["user_id"], self.user) + + # initiate a UI Auth process by attempting to delete the device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # check that SSO is offered + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + + # run the UIA-via-SSO flow + session_id = channel.json_body["session"] + channel = self.helper.auth_via_oidc( + remote_user_id=remote_user_id, ui_auth_session_id=session_id + ) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + + # and now the delete request should succeed. + self.delete_device( + self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}}, + ) + @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self): diff --git a/tests/server.py b/tests/server.py index 7d1ad362c4..5a1b66270f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -2,7 +2,7 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable, Iterable, Optional, Tuple, Union +from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union import attr from typing_extensions import Deque @@ -51,9 +51,21 @@ class FakeChannel: @property def json_body(self): - if not self.result: - raise Exception("No result yet.") - return json.loads(self.result["body"].decode("utf8")) + return json.loads(self.text_body) + + @property + def text_body(self) -> str: + """The body of the result, utf-8-decoded. + + Raises an exception if the request has not yet completed. 
+        """
+        if not self.is_finished():
+            raise Exception("Request not yet completed")
+        return self.result["body"].decode("utf8")
+
+    def is_finished(self) -> bool:
+        """check if the response has been completely received"""
+        return self.result.get("done", False)
 
     @property
     def code(self):
@@ -124,7 +136,7 @@ class FakeChannel:
             self._reactor.run()
         x = 0
 
-        while not self.result.get("done"):
+        while not self.is_finished():
             # If there's a producer, tell it to resume producing so we get content
             if self._producer:
                 self._producer.resumeProducing()
@@ -136,6 +148,16 @@ class FakeChannel:
 
             self._reactor.advance(0.1)
 
+    def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
+        """Process the contents of any Set-Cookie headers in the response
+
+        Any cookies found are added to the given dict
+        """
+        for h in self.headers.getRawHeaders("Set-Cookie"):
+            parts = h.split(";")
+            k, v = parts[0].split("=", maxsplit=1)
+            cookies[k] = v
+
 
 class FakeSite:
     """
-- cgit 1.5.1 

From 26d10331e5fa33507b680b45c44641e660a3adeb Mon Sep 17 00:00:00 2001
From: Richard van der Hoff
Date: Wed, 13 Jan 2021 12:35:23 +0000
Subject: Add a test for wrong user returned by SSO

---
 tests/rest/client/v2_alpha/test_auth.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)
 (limited to 'tests')

diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 5f6ca23b06..50630106ad 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -457,3 +457,30 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.assertIn({"stages": ["m.login.password"]}, flows)
         self.assertIn({"stages": ["m.login.sso"]}, flows)
         self.assertEqual(len(flows), 2)
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_ui_auth_fails_for_incorrect_sso_user(self):
+        """If the user tries to authenticate with the wrong SSO user, they get an error
+        """
+        # log the user in
+        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        # start a UI Auth flow by attempting to delete a device
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+        session_id = channel.json_body["session"]
+
+        # do the OIDC auth, but auth as the wrong user
+        channel = self.helper.auth_via_oidc("wrong_user", ui_auth_session_id=session_id)
+
+        # that should return a failure message
+        self.assertSubstring("We were unable to validate", channel.text_body)
+
+        # ... and the delete op should now fail with a 403
+        self.delete_device(
+            self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
+        )
-- cgit 1.5.1 

From 21a296cd5ac9c450a6e8896e25f0a4afad1c3774 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Thu, 14 Jan 2021 13:29:17 +0000
Subject: Split OidcProvider out of OidcHandler (#9107)

The idea here is that we will have an instance of OidcProvider for each
configured IdP, with OidcHandler just doing the marshalling of them.

For now it's still hardcoded with a single provider.
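
As a sketch of the marshalling shape this is heading towards (hypothetical: the stub
classes and registry below are invented for illustration, not code from this patch),
a handler keyed on idp_id might look like:

    from typing import Dict, Iterable

    class OidcProviderStub:
        """Stand-in for OidcProvider: carries only its idp_id."""

        def __init__(self, idp_id: str):
            self.idp_id = idp_id

    class OidcHandlerStub:
        """Routes each callback to the provider named in the session."""

        def __init__(self, providers: Iterable[OidcProviderStub]):
            # map from idp_id to the provider instance for that IdP
            self._providers = {p.idp_id: p for p in providers}  # type: Dict[str, OidcProviderStub]

        def provider_for(self, idp_id: str) -> OidcProviderStub:
            provider = self._providers.get(idp_id)
            if provider is None:
                raise ValueError("unknown identity provider: %s" % (idp_id,))
            return provider

    handler = OidcHandlerStub([OidcProviderStub("oidc"), OidcProviderStub("github")])
    assert handler.provider_for("github").idp_id == "github"
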
--- changelog.d/9107.feature | 1 + synapse/app/homeserver.py | 1 - synapse/handlers/oidc_handler.py | 246 +++++++++++++++++++++++---------------- tests/handlers/test_oidc.py | 93 ++++++++------- 4 files changed, 197 insertions(+), 144 deletions(-) create mode 100644 changelog.d/9107.feature (limited to 'tests') diff --git a/changelog.d/9107.feature b/changelog.d/9107.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9107.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index cbecf23be6..57a2f5237c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -429,7 +429,6 @@ def setup(config_options): oidc = hs.get_oidc_handler() # Loading the provider metadata also ensures the provider config is valid. await oidc.load_metadata() - await oidc.load_jwks() await _base.start(hs, config.listeners) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 84754e5c9c..d6347bb1b8 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -35,6 +35,7 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from synapse.config import ConfigError +from synapse.config.oidc_config import OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable @@ -70,6 +71,131 @@ JWK = Dict[str, str] JWKS = TypedDict("JWKS", {"keys": List[JWK]}) +class OidcHandler: + """Handles requests related to the OpenID Connect login flow. + """ + + def __init__(self, hs: "HomeServer"): + self._sso_handler = hs.get_sso_handler() + + provider_conf = hs.config.oidc.oidc_provider + # we should not have been instantiated if there is no configured provider. + assert provider_conf is not None + + self._token_generator = OidcSessionTokenGenerator(hs) + + self._provider = OidcProvider(hs, self._token_generator, provider_conf) + + async def load_metadata(self) -> None: + """Validate the config and load the metadata from the remote endpoint. + + Called at startup to ensure we have everything we need. + """ + await self._provider.load_metadata() + await self._provider.load_jwks() + + async def handle_oidc_callback(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/oidc/callback + + Since we might want to display OIDC-related errors in a user-friendly + way, we don't raise SynapseError from here. Instead, we call + ``self._sso_handler.render_error`` which displays an HTML page for the error. + + Most of the OpenID Connect logic happens here: + + - first, we check if there was any error returned by the provider and + display it + - then we fetch the session cookie, decode and verify it + - the ``state`` query parameter should match with the one stored in the + session cookie + + Once we know the session is legit, we then delegate to the OIDC Provider + implementation, which will exchange the code with the provider and complete the + login/authentication. + + Args: + request: the incoming request from the browser. + """ + + # The provider might redirect with an error. + # In that case, just display it as-is. + if b"error" in request.args: + # error response from the auth server. 
see:
+            # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+            # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+            error = request.args[b"error"][0].decode()
+            description = request.args.get(b"error_description", [b""])[0].decode()
+
+            # Most of the errors returned by the provider could be due to
+            # either the provider misbehaving or Synapse being misconfigured.
+            # The only exception to that is "access_denied", where the user
+            # probably cancelled the login flow. In other cases, log those errors.
+            if error != "access_denied":
+                logger.error("Error from the OIDC provider: %s %s", error, description)
+
+            self._sso_handler.render_error(request, error, description)
+            return
+
+        # otherwise, it is presumably a successful response. see:
+        # https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+        # Fetch the session cookie
+        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
+        if session is None:
+            logger.info("No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
+            return
+
+        # Remove the cookie. There is a good chance that if the callback failed
+        # once, it will fail next time and the code will already be exchanged.
+        # Removing it early avoids spamming the provider with token requests.
+        request.addCookie(
+            SESSION_COOKIE_NAME,
+            b"",
+            path="/_synapse/oidc",
+            expires="Thu, Jan 01 1970 00:00:00 UTC",
+            httpOnly=True,
+            sameSite="lax",
+        )
+
+        # Check for the state query parameter
+        if b"state" not in request.args:
+            logger.info("State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
+            return
+
+        state = request.args[b"state"][0].decode()
+
+        # Deserialize the session token and verify it.
+        try:
+            session_data = self._token_generator.verify_oidc_session_token(
+                session, state
+            )
+        except MacaroonDeserializationException as e:
+            logger.exception("Invalid session")
+            self._sso_handler.render_error(request, "invalid_session", str(e))
+            return
+        except MacaroonInvalidSignatureException as e:
+            logger.exception("Could not verify session")
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
+            return
+
+        if b"code" not in request.args:
+            logger.info("Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
+            return
+
+        code = request.args[b"code"][0].decode()
+
+        await self._provider.handle_oidc_callback(request, session_data, code)
+
+
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint
     """
@@ -84,21 +210,25 @@ class OidcError(Exception):
         return self.error
 
 
-class OidcHandler:
-    """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+    """Wraps the config for a single OIDC IdentityProvider
+
+    Provides methods for handling redirect requests and callbacks via that particular
+    IdP.
     """
 
-    def __init__(self, hs: "HomeServer"):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        token_generator: "OidcSessionTokenGenerator",
+        provider: OidcProviderConfig,
+    ):
         self._store = hs.get_datastore()
 
-        self._token_generator = OidcSessionTokenGenerator(hs)
+        self._token_generator = token_generator
 
         self._callback_url = hs.config.oidc_callback_url  # type: str
 
-        provider = hs.config.oidc.oidc_provider
-        # we should not have been instantiated if there is no configured provider.
- assert provider is not None - self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method self._client_auth = ClientAuth( @@ -552,22 +682,16 @@ class OidcHandler: nonce=nonce, ) - async def handle_oidc_callback(self, request: SynapseRequest) -> None: + async def handle_oidc_callback( + self, request: SynapseRequest, session_data: "OidcSessionData", code: str + ) -> None: """Handle an incoming request to /_synapse/oidc/callback - Since we might want to display OIDC-related errors in a user-friendly - way, we don't raise SynapseError from here. Instead, we call - ``self._sso_handler.render_error`` which displays an HTML page for the error. + By this time we have already validated the session on the synapse side, and + now need to do the provider-specific operations. This includes: - Most of the OpenID Connect logic happens here: - - - first, we check if there was any error returned by the provider and - display it - - then we fetch the session cookie, decode and verify it - - the ``state`` query parameter should match with the one stored in the - session cookie - - once we known this session is legit, exchange the code with the - provider using the ``token_endpoint`` (see ``_exchange_code``) + - exchange the code with the provider using the ``token_endpoint`` (see + ``_exchange_code``) - once we have the token, use it to either extract the UserInfo from the ``id_token`` (``_parse_id_token``), or use the ``access_token`` to fetch UserInfo from the ``userinfo_endpoint`` @@ -577,86 +701,12 @@ class OidcHandler: Args: request: the incoming request from the browser. + session_data: the session data, extracted from our cookie + code: The authorization code we got from the callback. """ - - # The provider might redirect with an error. - # In that case, just display it as-is. - if b"error" in request.args: - # error response from the auth server. see: - # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 - # https://openid.net/specs/openid-connect-core-1_0.html#AuthError - error = request.args[b"error"][0].decode() - description = request.args.get(b"error_description", [b""])[0].decode() - - # Most of the errors returned by the provider could be due by - # either the provider misbehaving or Synapse being misconfigured. - # The only exception of that is "access_denied", where the user - # probably cancelled the login flow. In other cases, log those errors. - if error != "access_denied": - logger.error("Error from the OIDC provider: %s %s", error, description) - - self._sso_handler.render_error(request, error, description) - return - - # otherwise, it is presumably a successful response. see: - # https://tools.ietf.org/html/rfc6749#section-4.1.2 - - # Fetch the session cookie - session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] - if session is None: - logger.info("No session cookie found") - self._sso_handler.render_error( - request, "missing_session", "No session cookie found" - ) - return - - # Remove the cookie. There is a good chance that if the callback failed - # once, it will fail next time and the code will already be exchanged. - # Removing it early avoids spamming the provider with token requests. 
- request.addCookie( - SESSION_COOKIE_NAME, - b"", - path="/_synapse/oidc", - expires="Thu, Jan 01 1970 00:00:00 UTC", - httpOnly=True, - sameSite="lax", - ) - - # Check for the state query parameter - if b"state" not in request.args: - logger.info("State parameter is missing") - self._sso_handler.render_error( - request, "invalid_request", "State parameter is missing" - ) - return - - state = request.args[b"state"][0].decode() - - # Deserialize the session token and verify it. - try: - session_data = self._token_generator.verify_oidc_session_token( - session, state - ) - except MacaroonDeserializationException as e: - logger.exception("Invalid session") - self._sso_handler.render_error(request, "invalid_session", str(e)) - return - except MacaroonInvalidSignatureException as e: - logger.exception("Could not verify session") - self._sso_handler.render_error(request, "mismatching_session", str(e)) - return - # Exchange the code with the provider - if b"code" not in request.args: - logger.info("Code parameter is missing") - self._sso_handler.render_error( - request, "invalid_request", "Code parameter is missing" - ) - return - - logger.debug("Exchanging code") - code = request.args[b"code"][0].decode() try: + logger.debug("Exchanging code") token = await self._exchange_code(code) except OidcError as e: logger.exception("Could not exchange code") diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 2abd7a83b5..5d338bea87 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase): hs = self.setup_test_homeserver(proxied_http_client=self.http_client) self.handler = hs.get_oidc_handler() + self.provider = self.handler._provider sso_handler = hs.get_sso_handler() # Mock the render error method. 
self.render_error = Mock(return_value=None) @@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase): return hs def metadata_edit(self, values): - return patch.dict(self.handler._provider_metadata, values) + return patch.dict(self.provider._provider_metadata, values) def assertRenderedError(self, error, error_description=None): + self.render_error.assert_called_once() args = self.render_error.call_args[0] self.assertEqual(args[1], error) if error_description is not None: @@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" - self.assertEqual(self.handler._callback_url, CALLBACK_URL) - self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) - self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) + self.assertEqual(self.provider._callback_url, CALLBACK_URL) + self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) + self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {"discover": True}}) def test_discovery(self): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid - metadata = self.get_success(self.handler.load_metadata()) + metadata = self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_called_once_with(WELL_KNOWN) self.assertEqual(metadata.issuer, ISSUER) @@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase): # subsequent calls should be cached self.http_client.reset_mock() - self.get_success(self.handler.load_metadata()) + self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) def test_no_discovery(self): """When discovery is disabled, it should not try to load from discovery document.""" - self.get_success(self.handler.load_metadata()) + self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) def test_load_jwks(self): """JWKS loading is done once (then cached) if used.""" - jwks = self.get_success(self.handler.load_jwks()) + jwks = self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) self.assertEqual(jwks, {"keys": []}) # subsequent calls should be cached… self.http_client.reset_mock() - self.get_success(self.handler.load_jwks()) + self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_not_called() # …unless forced self.http_client.reset_mock() - self.get_success(self.handler.load_jwks(force=True)) + self.get_success(self.provider.load_jwks(force=True)) self.http_client.get_json.assert_called_once_with(JWKS_URI) # Throw if the JWKS uri is missing with self.metadata_edit({"jwks_uri": None}): - self.get_failure(self.handler.load_jwks(force=True), RuntimeError) + self.get_failure(self.provider.load_jwks(force=True), RuntimeError) # Return empty key set if JWKS are not used - self.handler._scopes = [] # not asking the openid scope + self.provider._scopes = [] # not asking the openid scope self.http_client.get_json.reset_mock() - jwks = self.get_success(self.handler.load_jwks(force=True)) + jwks = self.get_success(self.provider.load_jwks(force=True)) self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) @override_config({"oidc_config": COMMON_CONFIG}) def test_validate_config(self): 
"""Provider metadatas are extensively validated.""" - h = self.handler + h = self.provider # Default test config does not throw h._validate_metadata() @@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase): """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw - self.handler._validate_metadata() + self.provider._validate_metadata() def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["addCookie"]) url = self.get_success( - self.handler.handle_redirect_request(req, b"http://client/redirect") + self.provider.handle_redirect_request(req, b"http://client/redirect") ) url = urlparse(url) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) @@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # ensure that we are correctly testing the fallback when "get_extra_attributes" # is not implemented. - mapping_provider = self.handler._user_mapping_provider + mapping_provider = self.provider._user_mapping_provider with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes @@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.handler._exchange_code = simple_async_mock(return_value=token) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.assert_called_once_with( expected_user_id, request, client_redirect_url, None, ) - self.handler._exchange_code.assert_called_once_with(code) - self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._fetch_userinfo.assert_not_called() + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.provider._fetch_userinfo.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors with patch.object( - self.handler, + self.provider, "_remote_id_from_userinfo", new=Mock(side_effect=MappingException()), ): @@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.handler._parse_id_token = simple_async_mock(raises=Exception()) + self.provider._parse_id_token = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") auth_handler.complete_sso_login.reset_mock() - self.handler._exchange_code.reset_mock() - self.handler._parse_id_token.reset_mock() - self.handler._fetch_userinfo.reset_mock() + self.provider._exchange_code.reset_mock() + self.provider._parse_id_token.reset_mock() + self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.handler._scopes = [] # do not ask the "openid" scope + self.provider._scopes = [] # do not ask the "openid" scope self.get_success(self.handler.handle_oidc_callback(request)) 
auth_handler.complete_sso_login.assert_called_once_with( expected_user_id, request, client_redirect_url, None, ) - self.handler._exchange_code.assert_called_once_with(code) - self.handler._parse_id_token.assert_not_called() - self.handler._fetch_userinfo.assert_called_once_with(token) + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_not_called() + self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() # Handle userinfo fetching error - self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) + self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure from synapse.handlers.oidc_handler import OidcError - self.handler._exchange_code = simple_async_mock( + self.provider._exchange_code = simple_async_mock( raises=OidcError("invalid_request") ) self.get_success(self.handler.handle_oidc_callback(request)) @@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) ) code = "code" - ret = self.get_success(self.handler._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code)) kwargs = self.http_client.request.call_args[1] self.assertEqual(ret, token) @@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) from synapse.handlers.oidc_handler import OidcError - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "foo") self.assertEqual(exc.value.error_description, "bar") @@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase): code=500, phrase=b"Internal Server Error", body=b"Not JSON", ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body @@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field @@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase): code=200, phrase=b"OK", body=b'{"error": "some_error"}', ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @override_config( @@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "foo", "phone": "1234567", } - self.handler._exchange_code = simple_async_mock(return_value=token) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) auth_handler = 
self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -979,9 +981,10 @@ async def _make_callback_with_userinfo( from synapse.handlers.oidc_handler import OidcSessionData handler = hs.get_oidc_handler() - handler._exchange_code = simple_async_mock(return_value={}) - handler._parse_id_token = simple_async_mock(return_value=userinfo) - handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + provider = handler._provider + provider._exchange_code = simple_async_mock(return_value={}) + provider._parse_id_token = simple_async_mock(return_value=userinfo) + provider._fetch_userinfo = simple_async_mock(return_value=userinfo) state = "state" session = handler._token_generator.generate_oidc_session_token( -- cgit 1.5.1 From 7036e24e98fc21855c34876d7024015470721bbe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 Jan 2021 15:18:27 +0000 Subject: Add background update for add chain cover index (#9029) --- changelog.d/8868.misc | 2 +- changelog.d/9029.misc | 1 + scripts/synapse_port_db | 2 +- synapse/storage/databases/main/events.py | 50 ++++-- .../storage/databases/main/events_bg_updates.py | 192 ++++++++++++++++++++- .../main/schema/delta/59/06chain_cover_index.sql | 17 ++ tests/storage/test_event_chain.py | 114 ++++++++++++ 7 files changed, 360 insertions(+), 18 deletions(-) create mode 100644 changelog.d/9029.misc create mode 100644 synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql (limited to 'tests') diff --git a/changelog.d/8868.misc b/changelog.d/8868.misc index 1a11e30944..346741d982 100644 --- a/changelog.d/8868.misc +++ b/changelog.d/8868.misc @@ -1 +1 @@ -Improve efficiency of large state resolutions for new rooms. +Improve efficiency of large state resolutions. diff --git a/changelog.d/9029.misc b/changelog.d/9029.misc new file mode 100644 index 0000000000..346741d982 --- /dev/null +++ b/changelog.d/9029.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 22dd169bfb..69bf9110a6 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { "events": ["processed", "outlier", "contains_url"], - "rooms": ["is_public"], + "rooms": ["is_public", "has_auth_chain_index"], "event_edges": ["is_state"], "presence_list": ["accepted"], "presence_stream": ["currently_active"], diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 186f064036..e0fbcc58cf 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -466,9 +466,6 @@ class PersistEventsStore: if not state_events: return - # Map from event ID to chain ID/sequence number. - chain_map = {} # type: Dict[str, Tuple[int, int]] - # We need to know the type/state_key and auth events of the events we're # calculating chain IDs for. 
We don't rely on having the full Event # instances as we'll potentially be pulling more events from the DB and @@ -479,9 +476,33 @@ class PersistEventsStore: event_to_auth_chain = { e.event_id: e.auth_event_ids() for e in state_events.values() } + event_to_room_id = {e.event_id: e.room_id for e in state_events.values()} + + self._add_chain_cover_index( + txn, event_to_room_id, event_to_types, event_to_auth_chain + ) + + def _add_chain_cover_index( + self, + txn, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, List[str]], + ) -> None: + """Calculate the chain cover index for the given events. + + Args: + event_to_room_id: Event ID to the room ID of the event + event_to_types: Event ID to type and state_key of the event + event_to_auth_chain: Event ID to list of auth event IDs of the + event (events with no auth events can be excluded). + """ + + # Map from event ID to chain ID/sequence number. + chain_map = {} # type: Dict[str, Tuple[int, int]] # Set of event IDs to calculate chain ID/seq numbers for. - events_to_calc_chain_id_for = set(state_events) + events_to_calc_chain_id_for = set(event_to_room_id) # We check if there are any events that need to be handled in the rooms # we're looking at. These should just be out of band memberships, where @@ -491,7 +512,7 @@ class PersistEventsStore: table="event_auth_chain_to_calculate", keyvalues={}, column="room_id", - iterable={e.room_id for e in state_events.values()}, + iterable=set(event_to_room_id.values()), retcols=("event_id", "type", "state_key"), ) for row in rows: @@ -582,16 +603,17 @@ class PersistEventsStore: # the list of events to calculate chain IDs for next time # around. (Otherwise we will have already added it to the # table). - event = state_events.get(event_id) - if event: + room_id = event_to_room_id.get(event_id) + if room_id: + e_type, state_key = event_to_types[event_id] self.db_pool.simple_insert_txn( txn, table="event_auth_chain_to_calculate", values={ - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, + "event_id": event_id, + "room_id": room_id, + "type": e_type, + "state_key": state_key, }, ) @@ -617,7 +639,7 @@ class PersistEventsStore: events_to_calc_chain_id_for, event_to_auth_chain ): existing_chain_id = None - for auth_id in event_to_auth_chain[event_id]: + for auth_id in event_to_auth_chain.get(event_id, []): if event_to_types.get(event_id) == event_to_types.get(auth_id): existing_chain_id = chain_map[auth_id] break @@ -730,11 +752,11 @@ class PersistEventsStore: # auth events (A, B) to check if B is reachable from A. reduction = { a_id - for a_id in event_to_auth_chain[event_id] + for a_id in event_to_auth_chain.get(event_id, []) if chain_map[a_id][0] != chain_id } for start_auth_id, end_auth_id in itertools.permutations( - event_to_auth_chain[event_id], r=2, + event_to_auth_chain.get(event_id, []), r=2, ): if chain_links.exists_path_from( chain_map[start_auth_id], chain_map[end_auth_id] diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 7e4b175d08..90a40a92b4 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -14,13 +14,13 @@ # limitations under the License. 
import logging
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 
@@ -108,6 +108,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             "rejected_events_metadata", self._rejected_events_metadata,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "chain_cover", self._chain_cover_index,
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -706,3 +710,187 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         )
 
         return len(results)
+
+    async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
+        """A background update that iterates over all rooms and generates the
+        chain cover index for them.
+        """
+
+        current_room_id = progress.get("current_room_id", "")
+
+        # Have we finished processing the current room.
+        finished = progress.get("finished", True)
+
+        # Where we've processed up to in the room, defaults to the start of the
+        # room.
+        last_depth = progress.get("last_depth", -1)
+        last_stream = progress.get("last_stream", -1)
+
+        # Have we set the `has_auth_chain_index` for the room yet.
+        has_set_room_has_chain_index = progress.get(
+            "has_set_room_has_chain_index", False
+        )
+
+        if finished:
+            # If we've finished with the previous room (or it's our first
+            # iteration) we move on to the next room.
+
+            def _get_next_room(txn: Cursor) -> Optional[str]:
+                sql = """
+                    SELECT room_id FROM rooms
+                    WHERE room_id > ?
+                        AND (
+                            NOT has_auth_chain_index
+                            OR has_auth_chain_index IS NULL
+                        )
+                    ORDER BY room_id
+                    LIMIT 1
+                """
+                txn.execute(sql, (current_room_id,))
+                row = txn.fetchone()
+                if row:
+                    return row[0]
+
+                return None
+
+            current_room_id = await self.db_pool.runInteraction(
+                "_chain_cover_index", _get_next_room
+            )
+            if not current_room_id:
+                await self.db_pool.updates._end_background_update("chain_cover")
+                return 0
+
+            logger.debug("Adding chain cover to %s", current_room_id)
+
+        def _calculate_auth_chain(
+            txn: Cursor, last_depth: int, last_stream: int
+        ) -> Tuple[int, int, int]:
+            # Get the next set of events in the room (that we haven't already
+            # computed chain cover for). We do this in topological order.
+
+            # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+            # comparison, but that is not supported on older SQLite versions
+            tuple_clause, tuple_args = make_tuple_comparison_clause(
+                self.database_engine,
+                [
+                    ("topological_ordering", last_depth),
+                    ("stream_ordering", last_stream),
+                ],
+            )
+
+            sql = """
+                SELECT
+                    event_id, state_events.type, state_events.state_key,
+                    topological_ordering, stream_ordering
+                FROM events
+                INNER JOIN state_events USING (event_id)
+                LEFT JOIN event_auth_chains USING (event_id)
+                LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+                WHERE events.room_id = ?
+                    AND event_auth_chains.event_id IS NULL
+                    AND event_auth_chain_to_calculate.event_id IS NULL
+                    AND %(tuple_cmp)s
+                ORDER BY topological_ordering, stream_ordering
+                LIMIT ?
+            """ % {
+                "tuple_cmp": tuple_clause,
+            }
+
+            args = [current_room_id]
+            args.extend(tuple_args)
+            args.append(batch_size)
+
+            txn.execute(sql, args)
+            rows = txn.fetchall()
+
+            # Put the results in the necessary format for
+            # `_add_chain_cover_index`
+            event_to_room_id = {row[0]: current_room_id for row in rows}
+            event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+            new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+            new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+
+            count = len(rows)
+
+            # We also need to fetch the auth events for them.
+            auth_events = self.db_pool.simple_select_many_txn(
+                txn,
+                table="event_auth",
+                column="event_id",
+                iterable=event_to_room_id,
+                keyvalues={},
+                retcols=("event_id", "auth_id"),
+            )
+
+            event_to_auth_chain = {}  # type: Dict[str, List[str]]
+            for row in auth_events:
+                event_to_auth_chain.setdefault(row["event_id"], []).append(
+                    row["auth_id"]
+                )
+
+            # Calculate and persist the chain cover index for this set of events.
+            #
+            # Annoyingly we need to gut wrench into the persist event store so that
+            # we can reuse the function to calculate the chain cover for rooms.
+            self.hs.get_datastores().persist_events._add_chain_cover_index(
+                txn, event_to_room_id, event_to_types, event_to_auth_chain,
+            )
+
+            return new_last_depth, new_last_stream, count
+
+        last_depth, last_stream, count = await self.db_pool.runInteraction(
+            "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+        )
+
+        total_rows_processed = count
+
+        if count < batch_size and not has_set_room_has_chain_index:
+            # If we've done all the events in the room we flip the
+            # `has_auth_chain_index` in the DB. Note that it's possible for
+            # further events to be persisted between the above and setting the
+            # flag without having the chain cover calculated for them. This is
+            # fine as a) the code gracefully handles these cases and b) we'll
+            # calculate them below.
+
+            await self.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": current_room_id},
+                updatevalues={"has_auth_chain_index": True},
+                desc="_chain_cover_index",
+            )
+            has_set_room_has_chain_index = True
+
+            # Handle any events that might have raced with us flipping the
+            # bit above.
+            last_depth, last_stream, count = await self.db_pool.runInteraction(
+                "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+            )
+
+            total_rows_processed += count
+
+            # Note that at this point it's technically possible that more events
+            # than our `batch_size` have been persisted without their chain
+            # cover, so we need to continue processing this room if the last
+            # count returned was equal to the `batch_size`.
+
+        if count < batch_size:
+            # We've finished calculating the index for this room, move on to the
+            # next room.
+            await self.db_pool.updates._background_update_progress(
+                "chain_cover", {"current_room_id": current_room_id, "finished": True},
+            )
+        else:
+            # We still have outstanding events to calculate the index for.
+ await self.db_pool.updates._background_update_progress( + "chain_cover", + { + "current_room_id": current_room_id, + "last_depth": last_depth, + "last_stream": last_stream, + "has_auth_chain_index": has_set_room_has_chain_index, + "finished": False, + }, + ) + + return total_rows_processed diff --git a/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql new file mode 100644 index 0000000000..fe3dca71dd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5906, 'chain_cover', '{}', 'rejected_events_metadata'); diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 83c377824b..ff67a73749 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -20,7 +20,10 @@ from twisted.trial import unittest from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.events import EventBase +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from synapse.storage.databases.main.events import _LinkMap +from synapse.types import create_requester from tests.unittest import HomeserverTestCase @@ -470,3 +473,114 @@ class LinkMapTestCase(unittest.TestCase): self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) + + +class EventChainBackgroundUpdateTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def test_background_update(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + user_id = self.register_user("foo", "pass") + token = self.login("foo", "pass") + room_id = self.helper.create_room_as(user_id, tok=token) + requester = create_requester(user_id) + + store = self.hs.get_datastore() + + # Mark the room as not having a chain cover index + self.get_success( + store.db_pool.simple_update( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": False}, + desc="test", + ) + ) + + # Create a fork in the DAG with different events. 
+ event_handler = self.hs.get_event_creation_handler() + latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + event, context = self.get_success( + event_handler.create_event( + requester, + { + "type": "some_state_type", + "state_key": "", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=latest_event_ids, + ) + ) + self.get_success( + event_handler.handle_new_client_event(requester, event, context) + ) + state1 = list(self.get_success(context.get_current_state_ids()).values()) + + event, context = self.get_success( + event_handler.create_event( + requester, + { + "type": "some_state_type", + "state_key": "", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=latest_event_ids, + ) + ) + self.get_success( + event_handler.handle_new_client_event(requester, event, context) + ) + state2 = list(self.get_success(context.get_current_state_ids()).values()) + + # Delete the chain cover info. + + def _delete_tables(txn): + txn.execute("DELETE FROM event_auth_chains") + txn.execute("DELETE FROM event_auth_chain_links") + + self.get_success(store.db_pool.runInteraction("test", _delete_tables)) + + # Insert and run the background update. + self.get_success( + store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + store.db_pool.updates._all_done = False + + while not self.get_success( + store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(store.has_auth_chain_index(room_id))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + store.db_pool.runInteraction( + "test", + store._get_auth_chain_difference_using_cover_index_txn, + room_id, + [state1, state2], + ) + ) -- cgit 1.5.1 From 1a08e0cdab0b3475fd4189aa1e3b6f9aaa823ccf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 Jan 2021 18:57:32 +0000 Subject: Fix event chain bg update. (#9118) We passed in a graph to `sorted_topologically` which didn't have an entry for each node (as we dropped nodes with no edges). --- changelog.d/9118.misc | 1 + synapse/util/iterutils.py | 2 +- tests/util/test_itertools.py | 8 ++++++++ 3 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9118.misc (limited to 'tests') diff --git a/changelog.d/9118.misc b/changelog.d/9118.misc new file mode 100644 index 0000000000..346741d982 --- /dev/null +++ b/changelog.d/9118.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions. 
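
The failure mode fixed below is easier to follow with a toy implementation in hand.
A minimal sketch in the spirit of sorted_topologically (whose full body is not shown
in this patch), where graph maps each node to the nodes it depends on:

    import heapq
    from typing import Dict, Iterable, Iterator, List

    def sorted_topologically_sketch(
        nodes: Iterable[int], graph: Dict[int, List[int]]
    ) -> Iterator[int]:
        """Yield nodes so that each node comes after its dependencies.

        A node with no entry in `graph` simply has no dependencies.
        """
        degree = {node: 0 for node in nodes}
        reverse = {}  # type: Dict[int, List[int]]
        for node in list(degree):
            for dep in graph.get(node, []):
                if dep in degree:
                    degree[node] += 1
                    reverse.setdefault(dep, []).append(node)

        zero_degree = [node for node, count in degree.items() if count == 0]
        heapq.heapify(zero_degree)
        while zero_degree:
            node = heapq.heappop(zero_degree)
            yield node
            # a bare reverse[node] here is the KeyError this commit fixes
            for dependent in reverse.get(node, []):
                degree[dependent] -= 1
                if degree[dependent] == 0:
                    heapq.heappush(zero_degree, dependent)

    # the case covered by the new regression test: nodes absent from `graph`
    assert list(sorted_topologically_sketch([1, 2], {})) == [1, 2]
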
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index f7b4857a84..6ef2b008a4 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -92,7 +92,7 @@ def sorted_topologically( node = heapq.heappop(zero_degree) yield node - for edge in reverse_graph[node]: + for edge in reverse_graph.get(node, []): if edge in degree_map: degree_map[edge] -= 1 if degree_map[edge] == 0: diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 1184cea5a3..522c8061f9 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -56,6 +56,14 @@ class SortTopologically(TestCase): graph = {} # type: Dict[int, List[int]] self.assertEqual(list(sorted_topologically([], graph)), []) + def test_handle_empty_graph(self): + "Test that a graph where a node doesn't have an entry is treated as empty" + + graph = {} # type: Dict[int, List[int]] + + # For disconnected nodes the output is simply sorted. + self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) + def test_disconnected(self): "Test that a graph with no edges work" -- cgit 1.5.1 From 4575ad0b1e86c814e6d1c3ca6ac31ba4eeeb5c66 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 Jan 2021 13:22:12 +0000 Subject: Store an IdP ID in the OIDC session (#9109) Again in preparation for handling more than one OIDC provider, add a new caveat to the macaroon used as an OIDC session cookie, which remembers which OIDC provider we are talking to. In future, when we get a callback, we'll need it to make sure we talk to the right IdP. As part of this, I'm adding an idp_id and idp_name field to the OIDC configuration object. They aren't yet documented, and we'll just use the old values by default. --- changelog.d/9109.feature | 1 + synapse/config/oidc_config.py | 26 +++++++++++++++++++++++--- synapse/handlers/oidc_handler.py | 22 ++++++++++++++++------ tests/handlers/test_oidc.py | 3 ++- 4 files changed, 42 insertions(+), 10 deletions(-) create mode 100644 changelog.d/9109.feature (limited to 'tests') diff --git a/changelog.d/9109.feature b/changelog.d/9109.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9109.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index c705de5694..fddca19223 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2020 Quentin Gliech -# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import string
 from typing import Optional, Type
 
 import attr
@@ -38,7 +39,7 @@ class OIDCConfig(Config):
 
         oidc_config = config.get("oidc_config")
         if oidc_config and oidc_config.get("enabled", False):
-            validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config")
+            validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
             self.oidc_provider = _parse_oidc_config_dict(oidc_config)
 
         if not self.oidc_provider:
@@ -205,6 +206,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
     "type": "object",
     "required": ["issuer", "client_id", "client_secret"],
     "properties": {
+        "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
+        "idp_name": {"type": "string"},
         "discover": {"type": "boolean"},
         "issuer": {"type": "string"},
         "client_id": {"type": "string"},
@@ -277,7 +280,17 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
             "methods: %s" % (", ".join(missing_methods),)
         )
 
+    # MSC2858 will apply certain limits on what can be used as an IdP id, so let's
+    # enforce those limits now.
+    idp_id = oidc_config.get("idp_id", "oidc")
+    valid_idp_chars = set(string.ascii_letters + string.digits + "-._~")
+
+    if any(c not in valid_idp_chars for c in idp_id):
+        raise ConfigError('idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"')
+
     return OidcProviderConfig(
+        idp_id=idp_id,
+        idp_name=oidc_config.get("idp_name", "OIDC"),
         discover=oidc_config.get("discover", True),
         issuer=oidc_config["issuer"],
         client_id=oidc_config["client_id"],
@@ -296,8 +309,15 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
     )
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class OidcProviderConfig:
+    # a unique identifier for this identity provider. Used in the 'user_external_ids'
+    # table, as well as the query/path parameter used in the login protocol.
+    idp_id = attr.ib(type=str)
+
+    # user-facing name for this identity provider.
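+    # (the config parser falls back to "OIDC" when this is unset, so
+    # grandfathered single-provider configs keep their old display name)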
+    idp_name = attr.ib(type=str)
+
     # whether the OIDC discovery mechanism is used to discover endpoints
     discover = attr.ib(type=bool)
 
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index d6347bb1b8..f63a90ec5c 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -175,7 +175,7 @@ class OidcHandler:
             session_data = self._token_generator.verify_oidc_session_token(
                 session, state
             )
-        except MacaroonDeserializationException as e:
+        except (MacaroonDeserializationException, ValueError) as e:
             logger.exception("Invalid session")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
@@ -253,10 +253,10 @@ class OidcProvider:
         self._server_name = hs.config.server_name  # type: str
 
         # identifier for the external_ids table
-        self.idp_id = "oidc"
+        self.idp_id = provider.idp_id
 
         # user-facing name of this auth provider
-        self.idp_name = "OIDC"
+        self.idp_name = provider.idp_name
 
         self._sso_handler = hs.get_sso_handler()
 
@@ -656,6 +656,7 @@ class OidcProvider:
         cookie = self._token_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
+                idp_id=self.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
                 ui_auth_session_id=ui_auth_session_id,
@@ -924,6 +925,7 @@ class OidcSessionTokenGenerator:
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = session")
         macaroon.add_first_party_caveat("state = %s" % (state,))
+        macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
         macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (session_data.client_redirect_url,)
@@ -952,6 +954,9 @@ class OidcSessionTokenGenerator:
 
         Returns:
             The data extracted from the session cookie
+
+        Raises:
+            ValueError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -960,6 +965,7 @@ class OidcSessionTokenGenerator:
         v.satisfy_exact("type = session")
         v.satisfy_exact("state = %s" % (state,))
         v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
         # Sometimes there's a UI auth session ID; it seems to be OK to attempt
         # to always satisfy this.
@@ -968,9 +974,9 @@ class OidcSessionTokenGenerator:
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-        # Extract the `nonce`, `client_redirect_url`, and maybe the
-        # `ui_auth_session_id` from the token.
+        # Extract the session data from the token.
         nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
         client_redirect_url = self._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
@@ -983,6 +989,7 @@ class OidcSessionTokenGenerator:
 
         return OidcSessionData(
             nonce=nonce,
+            idp_id=idp_id,
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=ui_auth_session_id,
         )
@@ -998,7 +1005,7 @@ class OidcSessionTokenGenerator:
             The extracted value
 
         Raises:
-            Exception: if the caveat was not in the macaroon
+            ValueError: if the caveat was not in the macaroon
         """
         prefix = key + " = "
         for caveat in macaroon.caveats:
@@ -1019,6 +1026,9 @@ class OidcSessionTokenGenerator:
 class OidcSessionData:
     """The attributes which are stored in an OIDC session cookie"""
 
+    # the Identity Provider being used
+    idp_id = attr.ib(type=str)
+
     # The `nonce` parameter passed to the OIDC provider.
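+    # (the nonce is a single-use value which ties the authorization request
+    # to the ID token that eventually comes back, protecting against replay)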
nonce = attr.ib(type=str) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 5d338bea87..38ae8ca19e 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -848,6 +848,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return self.handler._token_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( + idp_id="oidc", nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, @@ -990,7 +991,7 @@ async def _make_callback_with_userinfo( session = handler._token_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - nonce="nonce", client_redirect_url=client_redirect_url, + idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url, ), ) request = _build_callback_request("code", state, session) -- cgit 1.5.1 From 0dd2649c127e4eb538dfbf0c879bd66c9ff1599c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 Jan 2021 13:45:13 +0000 Subject: Improve UsernamePickerTestCase (#9112) * make the OIDC bits of the test work at a higher level - via the REST api instead of poking the OIDCHandler directly. * Move it to test_login.py, where I think it fits better. --- changelog.d/9112.misc | 1 + tests/handlers/test_oidc.py | 120 +------------------------------- tests/rest/client/v1/test_login.py | 105 +++++++++++++++++++++++++++- tests/rest/client/v1/utils.py | 11 +-- tests/rest/client/v2_alpha/test_auth.py | 2 +- 5 files changed, 114 insertions(+), 125 deletions(-) create mode 100644 changelog.d/9112.misc (limited to 'tests') diff --git a/changelog.d/9112.misc b/changelog.d/9112.misc new file mode 100644 index 0000000000..691f9d8b43 --- /dev/null +++ b/changelog.d/9112.misc @@ -0,0 +1 @@ +Improve `UsernamePickerTestCase`. diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 38ae8ca19e..02e21ed6ca 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,20 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -import re -from typing import Dict, Optional -from urllib.parse import parse_qs, urlencode, urlparse +from typing import Optional +from urllib.parse import parse_qs, urlparse from mock import ANY, Mock, patch import pymacaroons -from twisted.web.resource import Resource - -from synapse.api.errors import RedirectException from synapse.handlers.sso import MappingException -from synapse.rest.client.v1 import login -from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.server import HomeServer from synapse.types import UserID @@ -856,116 +850,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ) -class UsernamePickerTestCase(HomeserverTestCase): - if not HAS_OIDC: - skip = "requires OIDC" - - servlets = [login.register_servlets] - - def default_config(self): - config = super().default_config() - config["public_baseurl"] = BASE_URL - oidc_config = { - "enabled": True, - "client_id": CLIENT_ID, - "client_secret": CLIENT_SECRET, - "issuer": ISSUER, - "scopes": SCOPES, - "user_mapping_provider": { - "config": {"display_name_template": "{{ user.displayname }}"} - }, - } - - # Update this config with what's in the default config so that - # override_config works as expected. 
- oidc_config.update(config.get("oidc_config", {})) - config["oidc_config"] = oidc_config - - # whitelist this client URI so we redirect straight to it rather than - # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} - return config - - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) - return d - - def test_username_picker(self): - """Test the happy path of a username picker flow.""" - client_redirect_url = "https://whitelisted.client" - - # first of all, mock up an OIDC callback to the OidcHandler, which should - # raise a RedirectException - userinfo = {"sub": "tester", "displayname": "Jonny"} - f = self.get_failure( - _make_callback_with_userinfo( - self.hs, userinfo, client_redirect_url=client_redirect_url - ), - RedirectException, - ) - - # check the Location and cookies returned by the RedirectException - self.assertEqual(f.value.location, b"/_synapse/client/pick_username") - cookieheader = f.value.cookies[0] - regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);") - m = regex.search(cookieheader) - if not m: - self.fail("cookie header %s does not match %s" % (cookieheader, regex)) - - # introspect the sso handler a bit to check that the username mapping session - # looks ok. - session_id = m.group(1).decode("ascii") - username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions - self.assertIn( - session_id, username_mapping_sessions, "session id not found in map" - ) - session = username_mapping_sessions[session_id] - self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, client_redirect_url) - - # the expiry time should be about 15 minutes away - expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) - self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) - - # Now, submit a username to the username picker, which should serve a redirect - # back to the client - submit_path = f.value.location + b"/submit" - content = urlencode({b"username": b"bobby"}).encode("utf8") - chan = self.make_request( - "POST", - path=submit_path, - content=content, - content_is_form=True, - custom_headers=[ - ("Cookie", cookieheader), - # old versions of twisted don't do form-parsing without a valid - # content-length header. - ("Content-Length", str(len(content))), - ], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - # ensure that the returned location starts with the requested redirect URL - self.assertEqual( - location_headers[0][: len(client_redirect_url)], client_redirect_url - ) - - # fish the login token out of the returned redirect uri - parts = urlparse(location_headers[0]) - query = parse_qs(parts.query) - login_token = query["loginToken"][0] - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. 
- chan = self.make_request( - "POST", "/login", content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@bobby:test") - - async def _make_callback_with_userinfo( hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" ) -> None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index f9b8011961..73a009efd1 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -17,6 +17,7 @@ import time import urllib.parse from html.parser import HTMLParser from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from urllib.parse import parse_qs, urlencode, urlparse from mock import Mock @@ -30,13 +31,14 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.types import create_requester from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG -from tests.unittest import override_config, skip_unless +from tests.unittest import HomeserverTestCase, override_config, skip_unless try: import jwt @@ -1060,3 +1062,104 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"401", channel.result) + + +@skip_unless(HAS_OIDC, "requires OIDC") +class UsernamePickerTestCase(HomeserverTestCase): + """Tests for the username picker flow of SSO login""" + + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + + config["oidc_config"] = {} + config["oidc_config"].update(TEST_OIDC_CONFIG) + config["oidc_config"]["user_mapping_provider"] = { + "config": {"display_name_template": "{{ user.displayname }}"} + } + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + from synapse.rest.oidc import OIDCResource + + d = super().create_resource_dict() + d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + d["/_synapse/oidc"] = OIDCResource(self.hs) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + client_redirect_url = "https://whitelisted.client" + + # do the start of the login flow + channel = self.helper.auth_via_oidc( + {"sub": "tester", "displayname": "Jonny"}, client_redirect_url + ) + + # that should redirect to the username picker + self.assertEqual(channel.code, 302, channel.result) + picker_url = channel.headers.getRawHeaders("Location")[0] + self.assertEqual(picker_url, "/_synapse/client/pick_username") + + # ... with a username_mapping_session cookie + cookies = {} # type: Dict[str,str] + channel.extract_cookies(cookies) + self.assertIn("username_mapping_session", cookies) + session_id = cookies["username_mapping_session"] + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. 
+        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+        self.assertIn(
+            session_id, username_mapping_sessions, "session id not found in map",
+        )
+        session = username_mapping_sessions[session_id]
+        self.assertEqual(session.remote_user_id, "tester")
+        self.assertEqual(session.display_name, "Jonny")
+        self.assertEqual(session.client_redirect_url, client_redirect_url)
+
+        # the expiry time should be about 15 minutes away
+        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+        # Now, submit a username to the username picker, which should serve a redirect
+        # back to the client
+        submit_path = picker_url + "/submit"
+        content = urlencode({b"username": b"bobby"}).encode("utf8")
+        chan = self.make_request(
+            "POST",
+            path=submit_path,
+            content=content,
+            content_is_form=True,
+            custom_headers=[
+                ("Cookie", "username_mapping_session=" + session_id),
+                # old versions of twisted don't do form-parsing without a valid
+                # content-length header.
+                ("Content-Length", str(len(content))),
+            ],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+        # ensure that the returned location starts with the requested redirect URL
+        self.assertEqual(
+            location_headers[0][: len(client_redirect_url)], client_redirect_url
+        )
+
+        # fish the login token out of the returned redirect uri
+        parts = urlparse(location_headers[0])
+        query = parse_qs(parts.query)
+        login_token = query["loginToken"][0]
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 85d1709ead..c6647dbe08 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -363,10 +363,10 @@ class RestHelper:
         the normal places.
         """
         client_redirect_url = "https://x"
-        channel = self.auth_via_oidc(remote_user_id, client_redirect_url)
+        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
 
         # expect a confirmation page
-        assert channel.code == 200
+        assert channel.code == 200, channel.result
 
         # fish the matrix login token out of the body of the confirmation page
         m = re.search(
@@ -390,7 +390,7 @@ class RestHelper:
 
     def auth_via_oidc(
         self,
-        remote_user_id: str,
+        user_info_dict: JsonDict,
         client_redirect_url: Optional[str] = None,
         ui_auth_session_id: Optional[str] = None,
     ) -> FakeChannel:
@@ -411,7 +411,8 @@ class RestHelper:
         the normal places.
 
         Args:
-            remote_user_id: the remote id that the OIDC provider should present
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
             client_redirect_url: for a login flow, the client redirect URL to pass to
                 the login redirect endpoint
             ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
@@ -457,7 +458,7 @@ class RestHelper:
             # a dummy OIDC access token
             ("https://issuer.test/token", {"access_token": "TEST"}),
             # and then one to the user_info endpoint, which returns our remote user id.
-            ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+            ("https://issuer.test/userinfo", user_info_dict),
         ]
 
         async def mock_req(method: str, uri: str, data=None, headers=None):
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 50630106ad..3e8661f9b9 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -411,7 +411,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # run the UIA-via-SSO flow
         session_id = channel.json_body["session"]
         channel = self.helper.auth_via_oidc(
-            remote_user_id=remote_user_id, ui_auth_session_id=session_id
+            {"sub": remote_user_id}, ui_auth_session_id=session_id
         )
 
         # that should serve a confirmation page
-- 
cgit 1.5.1


From 74dd90604189a0310c7b2f7eed0e6b2ac26d04f1 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Fri, 15 Jan 2021 11:00:13 -0500
Subject: Avoid raising the body exceeded error multiple times. (#9108)

Previously this code generated unreferenced `Deferred` instances which
caused "Unhandled Deferreds" errors to appear in error situations.
---
 changelog.d/9108.bugfix                            |   1 +
 synapse/http/client.py                             |  12 ++-
 .../federation/test_matrix_federation_agent.py     |   4 +-
 tests/http/test_client.py                          | 101 +++++++++++++++++++++
 4 files changed, 115 insertions(+), 3 deletions(-)
 create mode 100644 changelog.d/9108.bugfix
 create mode 100644 tests/http/test_client.py

(limited to 'tests')

diff --git a/changelog.d/9108.bugfix b/changelog.d/9108.bugfix
new file mode 100644
index 0000000000..465ef63508
--- /dev/null
+++ b/changelog.d/9108.bugfix
@@ -0,0 +1 @@
+Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when handling .well-known files that are too large.
diff --git a/synapse/http/client.py b/synapse/http/client.py
index dc4b81ca60..df498c8645 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -766,14 +766,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         self.max_size = max_size
 
     def dataReceived(self, data: bytes) -> None:
+        # If the deferred was called, bail early.
+        if self.deferred.called:
+            return
+
         self.stream.write(data)
         self.length += len(data)
+        # The first time the maximum size is exceeded, error and cancel the
+        # connection. dataReceived might be called again if data was received
+        # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
             self.deferred.errback(BodyExceededMaxSize())
-            self.deferred = defer.Deferred()
             self.transport.loseConnection()
 
     def connectionLost(self, reason: Failure) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
         if reason.check(ResponseDone):
             self.deferred.callback(self.length)
         elif reason.check(PotentialDataLoss):
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 4e51839d0f..686012dd25 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -1095,7 +1095,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
         # Expire both caches and repeat the request
         self.reactor.pump((10000.0,))
 
-        # Repated the request, this time it should fail if the lookup fails.
+        # Repeat the request, this time it should fail if the lookup fails.
        fetch_d = defer.ensureDeferred(
            self.well_known_resolver.get_well_known(b"testserv")
        )
@@ -1130,7 +1130,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
         )
 
-        # The result is sucessful, but disabled delegation.
+        # The result is successful, but delegation is disabled.
         r = self.successResultOf(fetch_d)
         self.assertIsNone(r.delegated_server)
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
new file mode 100644
index 0000000000..f17c122e93
--- /dev/null
+++ b/tests/http/test_client.py
@@ -0,0 +1,101 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from io import BytesIO
+
+from mock import Mock
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
+from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+
+from tests.unittest import TestCase
+
+
+class ReadBodyWithMaxSizeTests(TestCase):
+    def setUp(self):
+        """Start reading the body, setting up the response, result and protocol."""
+        self.response = Mock()
+        self.result = BytesIO()
+        self.deferred = read_body_with_max_size(self.response, self.result, 6)
+
+        # Fish the protocol out of the response.
+        self.protocol = self.response.deliverBody.call_args[0][0]
+        self.protocol.transport = Mock()
+
+    def _cleanup_error(self):
+        """Ensure that the error in the Deferred is handled gracefully."""
+        called = [False]
+
+        def errback(f):
+            called[0] = True
+
+        self.deferred.addErrback(errback)
+        self.assertTrue(called[0])
+
+    def test_no_error(self):
+        """A response that is NOT too large."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"12345")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"12345")
+        self.assertEqual(self.deferred.result, 5)
+
+    def test_too_large(self):
+        """A response which is too large raises an exception."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"1234567890")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"1234567890")
+        self.assertIsInstance(self.deferred.result, Failure)
+        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+        self._cleanup_error()
+
+    def test_multiple_packets(self):
+        """Data should be accumulated through multiple packets."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"12")
+        self.protocol.dataReceived(b"34")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"1234")
+        self.assertEqual(self.deferred.result, 4)
+
+    def test_additional_data(self):
+        """A connection can receive data after being closed."""
+
+        # Start sending data.
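+        # (ten bytes exceeds the six-byte cap set up in setUp, so the deferred
+        # errbacks with BodyExceededMaxSize straight away)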
+        self.protocol.dataReceived(b"1234567890")
+        self.assertIsInstance(self.deferred.result, Failure)
+        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+        self.protocol.transport.loseConnection.assert_called_once()
+
+        # More data might have come in.
+        self.protocol.dataReceived(b"1234567890")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"1234567890")
+        self.assertIsInstance(self.deferred.result, Failure)
+        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+        self._cleanup_error()
-- 
cgit 1.5.1


From 3e4cdfe5d9b6152b9a0b7f7ad22623c2bf19f7ff Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Fri, 15 Jan 2021 11:18:09 -0500
Subject: Add an admin API endpoint to protect media. (#9086)

Protecting media stops it from being quarantined when e.g. all media in a
room is quarantined. This is useful for sticker packs and other media that
is uploaded by server administrators, but used by many people.
---
 changelog.d/9086.feature          |  1 +
 docs/admin_api/media_admin_api.md | 24 +++++++++++++++
 synapse/rest/admin/media.py       | 64 ++++++++++++++++++++++++++++++---------
 tests/rest/admin/test_admin.py    |  8 +++--
 4 files changed, 79 insertions(+), 18 deletions(-)
 create mode 100644 changelog.d/9086.feature

(limited to 'tests')

diff --git a/changelog.d/9086.feature b/changelog.d/9086.feature
new file mode 100644
index 0000000000..3e678e24d5
--- /dev/null
+++ b/changelog.d/9086.feature
@@ -0,0 +1 @@
+Add an admin API for protecting local media from quarantine.
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index dfb8c5d751..90faeaaef0 100644
--- a/docs/admin_api/media_admin_api.md
+++ b/docs/admin_api/media_admin_api.md
@@ -4,6 +4,7 @@
   * [Quarantining media by ID](#quarantining-media-by-id)
   * [Quarantining media in a room](#quarantining-media-in-a-room)
   * [Quarantining all media of a user](#quarantining-all-media-of-a-user)
+  * [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
 - [Delete local media](#delete-local-media)
   * [Delete a specific local media](#delete-a-specific-local-media)
   * [Delete local media by date or size](#delete-local-media-by-date-or-size)
@@ -123,6 +124,29 @@ The following fields are returned in the JSON response body:
 
 * `num_quarantined`: integer - The number of media items successfully quarantined
 
+## Protecting media from being quarantined
+
+This API protects a single piece of local media from being quarantined using the
+above APIs. This is useful for sticker packs and other shared media which you do
+not want to get quarantined, especially when
+[quarantining media in a room](#quarantining-media-in-a-room).
+
+Request:
+
+```
+POST /_synapse/admin/v1/media/protect/<media_id>
+
+{}
+```
+
+Where `media_id` is in the form of `abcdefg12345...`.
+
+Response:
+
+```json
+{}
+```
+
 # Delete local media
 This API deletes the *local* media from the disk of your own server. This includes
 any local thumbnails and copies of media downloaded from
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..8720b1401f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
     assert_requester_is_admin,
     assert_user_is_admin,
 )
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
         admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, room_id: str):
+    async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
 
     PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id: str):
+    async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
         "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, server_name: str, media_id: str):
+    async def on_POST(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
         return 200, {}
 
 
+class ProtectMediaByID(RestServlet):
+    """Protect local media from being quarantined.
+    """
+
+    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Protecting local media by ID: %s", media_id)
+
+        # Protect this media id
+        await self.store.mark_local_media_as_safe(media_id)
+
+        return 200, {}
+
+
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room.
""" PATTERNS = admin_patterns("/room/(?P[^/]+)/media") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: @@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet): class PurgeMediaCacheRestServlet(RestServlet): PATTERNS = admin_patterns("/purge_media_cache") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.media_repository = hs.get_media_repository() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]+)/(?P[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_DELETE(self, request, server_name: str, media_id: str): + async def on_DELETE( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: @@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P[^/]+)/delete") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_POST(self, request, server_name: str): + async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet): return 200, {"deleted_media": deleted_media, "total": total} -def register_servlets_for_media_repo(hs, http_server): +def register_servlets_for_media_repo(hs: "HomeServer", http_server): """ Media repo specific APIs. """ @@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server): QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) + ProtectMediaByID(hs).register(http_server) ListMediaInRoom(hs).register(http_server) DeleteMediaByID(hs).register(http_server) DeleteMediaByDateSize(hs).register(http_server) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 586b877bda..9d22c04073 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -153,8 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - # Allow for uploading and downloading to/from the media repo self.media_repo = hs.get_media_repository_resource() self.download_resource = self.media_repo.children[b"download"] @@ -428,7 +426,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Mark the second item as safe from quarantine. 
        _, media_id_2 = server_and_media_id_2.split("/")
-        self.get_success(self.store.mark_local_media_as_safe(media_id_2))
+        # Protect the media
+        url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
+        channel = self.make_request("POST", url, access_token=admin_user_tok)
+        self.pump(1.0)
+        self.assertEqual(200, int(channel.code), msg=channel.result["body"])
 
         # Quarantine all media by this user
         url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
-- 
cgit 1.5.1


From 9de6b9411750c9adf72bdd9d180d2f51b89e3c03 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Fri, 15 Jan 2021 16:55:29 +0000
Subject: Land support for multiple OIDC providers (#9110)

This is the final step for supporting multiple OIDC providers concurrently.

First of all, we reorganise the config so that you can specify a list of OIDC
providers, instead of a single one. Before:

    oidc_config:
       enabled: true
       issuer: "https://oidc_provider"
       # etc

After:

    oidc_providers:
      - idp_id: prov1
        issuer: "https://oidc_provider"

      - idp_id: prov2
        issuer: "https://another_oidc_provider"

The old format is still grandfathered in.

With that done, it's then simply a matter of having OidcHandler instantiate a
new OidcProvider for each configured provider.
---
 changelog.d/9110.feature         |   1 +
 docs/openid.md                   | 201 ++++++++++++------------
 docs/sample_config.yaml          | 274 ++++++++++++++++----------------
 synapse/config/cas.py            |   2 +-
 synapse/config/oidc_config.py    | 329 ++++++++++++++++++++++-----------------
 synapse/handlers/oidc_handler.py |  27 +++-
 tests/handlers/test_oidc.py      |   4 +-
 7 files changed, 456 insertions(+), 382 deletions(-)
 create mode 100644 changelog.d/9110.feature

(limited to 'tests')

diff --git a/changelog.d/9110.feature b/changelog.d/9110.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9110.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/docs/openid.md b/docs/openid.md
index ffa4238fff..b86ae89768 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -42,11 +42,10 @@ as follows:
 
   * For other installation mechanisms, see the documentation provided by the
     maintainer.
 
-To enable the OpenID integration, you should then add an `oidc_config` section
-to your configuration file (or uncomment the `enabled: true` line in the
-existing section). See [sample_config.yaml](./sample_config.yaml) for some
-sample settings, as well as the text below for example configurations for
-specific providers.
+To enable the OpenID integration, you should then add a section to the `oidc_providers`
+setting in your configuration file (or uncomment one of the existing examples).
+See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
+the text below for example configurations for specific providers.
 
 ## Sample configs
 
@@ -62,20 +61,21 @@ Directory (tenant) ID as it will be used in the Azure links.
 Edit your Synapse config file and update the `oidc_providers` section:
 
 ```yaml
-oidc_config:
-   enabled: true
-   issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
-   client_id: "<client id>"
-   client_secret: "<client secret>"
-   scopes: ["openid", "profile"]
-   authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
-   token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
-   userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
-
-   user_mapping_provider:
-     config:
-       localpart_template: "{{ user.preferred_username.split('@')[0] }}"
-       display_name_template: "{{ user.name }}"
+oidc_providers:
+  - idp_id: microsoft
+    idp_name: Microsoft
+    issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
+    client_id: "<client id>"
+    client_secret: "<client secret>"
+    scopes: ["openid", "profile"]
+    authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
+    token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
+    userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
+
+    user_mapping_provider:
+      config:
+        localpart_template: "{{ user.preferred_username.split('@')[0] }}"
+        display_name_template: "{{ user.name }}"
 ```
 
 ### [Dex][dex-idp]
@@ -103,17 +103,18 @@ Run with `dex serve examples/config-dev.yaml`.
 Synapse config:
 
 ```yaml
-oidc_config:
-   enabled: true
-   skip_verification: true # This is needed as Dex is served on an insecure endpoint
-   issuer: "http://127.0.0.1:5556/dex"
-   client_id: "synapse"
-   client_secret: "secret"
-   scopes: ["openid", "profile"]
-   user_mapping_provider:
-     config:
-       localpart_template: "{{ user.name }}"
-       display_name_template: "{{ user.name|capitalize }}"
+oidc_providers:
+  - idp_id: dex
+    idp_name: "My Dex server"
+    skip_verification: true # This is needed as Dex is served on an insecure endpoint
+    issuer: "http://127.0.0.1:5556/dex"
+    client_id: "synapse"
+    client_secret: "secret"
+    scopes: ["openid", "profile"]
+    user_mapping_provider:
+      config:
+        localpart_template: "{{ user.name }}"
+        display_name_template: "{{ user.name|capitalize }}"
 ```
 
 ### [Keycloak][keycloak-idp]
@@ -152,16 +153,17 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
 8.
Copy Secret ```yaml -oidc_config: - enabled: true - issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" - client_id: "synapse" - client_secret: "copy secret generated from above" - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: keycloak + idp_name: "My KeyCloak server" + issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" + client_id: "synapse" + client_secret: "copy secret generated from above" + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### [Auth0][auth0] @@ -191,16 +193,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: auth0 + idp_name: Auth0 + issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### GitHub @@ -219,21 +222,22 @@ does not return a `sub` property, an alternative `subject_claim` has to be set. Synapse config: ```yaml -oidc_config: - enabled: true - discover: false - issuer: "https://github.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - authorization_endpoint: "https://github.com/login/oauth/authorize" - token_endpoint: "https://github.com/login/oauth/access_token" - userinfo_endpoint: "https://api.github.com/user" - scopes: ["read:user"] - user_mapping_provider: - config: - subject_claim: "id" - localpart_template: "{{ user.login }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: github + idp_name: Github + discover: false + issuer: "https://github.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + authorization_endpoint: "https://github.com/login/oauth/authorize" + token_endpoint: "https://github.com/login/oauth/access_token" + userinfo_endpoint: "https://api.github.com/user" + scopes: ["read:user"] + user_mapping_provider: + config: + subject_claim: "id" + localpart_template: "{{ user.login }}" + display_name_template: "{{ user.name }}" ``` ### [Google][google-idp] @@ -243,16 +247,17 @@ oidc_config: 2. add an "OAuth Client ID" for a Web Application under "Credentials". 3. 
Copy the Client ID and Client Secret, and add the following to your synapse config: ```yaml - oidc_config: - enabled: true - issuer: "https://accounts.google.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.given_name|lower }}" - display_name_template: "{{ user.name }}" + oidc_providers: + - idp_id: google + idp_name: Google + issuer: "https://accounts.google.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.given_name|lower }}" + display_name_template: "{{ user.name }}" ``` 4. Back in the Google console, add this Authorized redirect URI: `[synapse public baseurl]/_synapse/oidc/callback`. @@ -266,16 +271,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://id.twitch.tv/oauth2/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: twitch + idp_name: Twitch + issuer: "https://id.twitch.tv/oauth2/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### GitLab @@ -287,16 +293,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://gitlab.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - scopes: ["openid", "read_user"] - user_profile_method: "userinfo_endpoint" - user_mapping_provider: - config: - localpart_template: '{{ user.nickname }}' - display_name_template: '{{ user.name }}' +oidc_providers: + - idp_id: gitlab + idp_name: Gitlab + issuer: "https://gitlab.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + scopes: ["openid", "read_user"] + user_profile_method: "userinfo_endpoint" + user_mapping_provider: + config: + localpart_template: '{{ user.nickname }}' + display_name_template: '{{ user.name }}' ``` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9da351f9f3..ae995efe9b 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1709,141 +1709,149 @@ saml2_config: #idp_entityid: 'https://our_idp/entityid' -# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. +# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration +# and login. # -# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md -# for some example configurations. +# Options for each entry include: # -oidc_config: - # Uncomment the following to enable authorization against an OpenID Connect - # server. Defaults to false. - # - #enabled: true - - # Uncomment the following to disable use of the OIDC discovery mechanism to - # discover endpoints. Defaults to true. - # - #discover: false - - # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to - # discover the provider's endpoints. 
- # - # Required if 'enabled' is true. - # - #issuer: "https://accounts.example.com/" - - # oauth2 client id to use. - # - # Required if 'enabled' is true. - # - #client_id: "provided-by-your-issuer" - - # oauth2 client secret to use. - # - # Required if 'enabled' is true. - # - #client_secret: "provided-by-your-issuer" - - # auth method to use when exchanging the token. - # Valid values are 'client_secret_basic' (default), 'client_secret_post' and - # 'none'. - # - #client_auth_method: client_secret_post - - # list of scopes to request. This should normally include the "openid" scope. - # Defaults to ["openid"]. - # - #scopes: ["openid", "profile"] - - # the oauth2 authorization endpoint. Required if provider discovery is disabled. - # - #authorization_endpoint: "https://accounts.example.com/oauth2/auth" - - # the oauth2 token endpoint. Required if provider discovery is disabled. - # - #token_endpoint: "https://accounts.example.com/oauth2/token" - - # the OIDC userinfo endpoint. Required if discovery is disabled and the - # "openid" scope is not requested. - # - #userinfo_endpoint: "https://accounts.example.com/userinfo" - - # URI where to fetch the JWKS. Required if discovery is disabled and the - # "openid" scope is used. - # - #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" - - # Uncomment to skip metadata verification. Defaults to false. - # - # Use this if you are connecting to a provider that is not OpenID Connect - # compliant. - # Avoid this in production. - # - #skip_verification: true - - # Whether to fetch the user profile from the userinfo endpoint. Valid - # values are: "auto" or "userinfo_endpoint". - # - # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included - # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. - # - #user_profile_method: "userinfo_endpoint" - - # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead - # of failing. This could be used if switching from password logins to OIDC. Defaults to false. - # - #allow_existing_users: true - - # An external module can be provided here as a custom solution to mapping - # attributes returned from a OIDC provider onto a matrix user. - # - user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. - # - # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers - # for information on implementing a custom mapping provider. - # - #module: mapping_provider.OidcMappingProvider - - # Custom configuration values for the module. This section will be passed as - # a Python dictionary to the user mapping provider module's `parse_config` - # method. - # - # The examples below are intended for the default provider: they should be - # changed if using a custom provider. - # - config: - # name of the claim containing a unique identifier for the user. - # Defaults to `sub`, which OpenID Connect compliant providers should provide. - # - #subject_claim: "sub" - - # Jinja2 template for the localpart of the MXID. - # - # When rendering, this template is given the following variables: - # * user: The claims returned by the UserInfo Endpoint and/or in the ID - # Token - # - # If this is not set, the user will be prompted to choose their - # own username. - # - #localpart_template: "{{ user.preferred_username }}" - - # Jinja2 template for the display name to set on first login. 
-  #
-  # If unset, no displayname will be set.
-  #
-  #display_name_template: "{{ user.given_name }} {{ user.last_name }}"
-
-  # Jinja2 templates for extra attributes to send back to the client during
-  # login.
-  #
-  # Note that these are non-standard and clients will ignore them without modifications.
-  #
-  #extra_attributes:
-    #birthdate: "{{ user.birthdate }}"
-
+# idp_id: a unique identifier for this identity provider. Used internally
+#     by Synapse; should be a single word such as 'github'.
+#
+#     Note that, if this is changed, users authenticating via that provider
+#     will no longer be recognised as the same user!
+#
+# idp_name: A user-facing name for this identity provider, which is used to
+#     offer the user a choice of login mechanisms.
+#
+# discover: set to 'false' to disable the use of the OIDC discovery mechanism
+#     to discover endpoints. Defaults to true.
+#
+# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+#     is enabled) to discover the provider's endpoints.
+#
+# client_id: Required. oauth2 client id to use.
+#
+# client_secret: Required. oauth2 client secret to use.
+#
+# client_auth_method: auth method to use when exchanging the token. Valid
+#     values are 'client_secret_basic' (default), 'client_secret_post' and
+#     'none'.
+#
+# scopes: list of scopes to request. This should normally include the "openid"
+#     scope. Defaults to ["openid"].
+#
+# authorization_endpoint: the oauth2 authorization endpoint. Required if
+#     provider discovery is disabled.
+#
+# token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+#     disabled.
+#
+# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+#     disabled and the 'openid' scope is not requested.
+#
+# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+#     the 'openid' scope is used.
+#
+# skip_verification: set to 'true' to skip metadata verification. Use this if
+#     you are connecting to a provider that is not OpenID Connect compliant.
+#     Defaults to false. Avoid this in production.
+#
+# user_profile_method: Whether to fetch the user profile from the userinfo
+#     endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+#
+#     Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+#     included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+#     userinfo endpoint.
+#
+# allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+#     match a pre-existing account instead of failing. This could be used if
+#     switching from password logins to OIDC. Defaults to false.
+#
+# user_mapping_provider: Configuration for how attributes returned from an OIDC
+#     provider are mapped onto a matrix user. This setting has the following
+#     sub-properties:
+#
+#     module: The class name of a custom mapping module. Default is
+#         'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
+#         See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+#         for information on implementing a custom mapping provider.
+#
+#     config: Configuration for the mapping provider module. This section will
+#         be passed as a Python dictionary to the user mapping provider
+#         module's `parse_config` method.
+#
+#         For the default provider, the following settings are available:
+#
+#           subject_claim: name of the claim containing a unique identifier for the
+#               user. Defaults to 'sub', which OpenID Connect compliant
+#               providers should provide.
+#
+#           localpart_template: Jinja2 template for the localpart of the MXID.
+#               If this is not set, the user will be prompted to choose their
+#               own username.
+#
+#           display_name_template: Jinja2 template for the display name to set
+#               on first login. If unset, no displayname will be set.
+#
+#           extra_attributes: a map of Jinja2 templates for extra attributes
+#               to send back to the client during login.
+#               Note that these are non-standard and clients will ignore them
+#               without modifications.
+#
+#           When rendering, the Jinja2 templates are given a 'user' variable,
+#           which is set to the claims returned by the UserInfo Endpoint and/or
+#           in the ID Token.
+#
+# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+# for information on how to configure these options.
+#
+# For backwards compatibility, it is also possible to configure a single OIDC
+# provider via an 'oidc_config' setting. This is now deprecated and admins are
+# advised to migrate to the 'oidc_providers' format.
+#
+oidc_providers:
+  # Generic example
+  #
+  #- idp_id: my_idp
+  #  idp_name: "My OpenID provider"
+  #  discover: false
+  #  issuer: "https://accounts.example.com/"
+  #  client_id: "provided-by-your-issuer"
+  #  client_secret: "provided-by-your-issuer"
+  #  client_auth_method: client_secret_post
+  #  scopes: ["openid", "profile"]
+  #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+  #  token_endpoint: "https://accounts.example.com/oauth2/token"
+  #  userinfo_endpoint: "https://accounts.example.com/userinfo"
+  #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+  #  skip_verification: true
+
+  # For use with Keycloak
+  #
+  #- idp_id: keycloak
+  #  idp_name: Keycloak
+  #  issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+  #  client_id: "synapse"
+  #  client_secret: "copy secret generated in Keycloak UI"
+  #  scopes: ["openid", "profile"]
+
+  # For use with Github
+  #
+  #- idp_id: github
+  #  idp_name: Github
+  #  discover: false
+  #  issuer: "https://github.com/"
+  #  client_id: "your-client-id" # TO BE FILLED
+  #  client_secret: "your-client-secret" # TO BE FILLED
+  #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+  #  token_endpoint: "https://github.com/login/oauth/access_token"
+  #  userinfo_endpoint: "https://api.github.com/user"
+  #  scopes: ["read:user"]
+  #  user_mapping_provider:
+  #    config:
+  #      subject_claim: "id"
+  #      localpart_template: "{{ user.login }}"
+  #      display_name_template: "{{ user.name }}"
 
 # Enable Central Authentication Service (CAS) for registration and login.
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 2f97e6d258..c7877b4095 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -40,7 +40,7 @@ class CasConfig(Config):
             self.cas_required_attributes = {}
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
-        return """
+        return """\
         # Enable Central Authentication Service (CAS) for registration and login.
         #
         cas_config:
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index fddca19223..c7fa749377 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,7 +15,7 @@
 # limitations under the License.
import string -from typing import Optional, Type +from typing import Iterable, Optional, Type import attr @@ -33,16 +33,8 @@ class OIDCConfig(Config): section = "oidc" def read_config(self, config, **kwargs): - validate_config(MAIN_CONFIG_SCHEMA, config, ()) - - self.oidc_provider = None # type: Optional[OidcProviderConfig] - - oidc_config = config.get("oidc_config") - if oidc_config and oidc_config.get("enabled", False): - validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",)) - self.oidc_provider = _parse_oidc_config_dict(oidc_config) - - if not self.oidc_provider: + self.oidc_providers = tuple(_parse_oidc_provider_configs(config)) + if not self.oidc_providers: return try: @@ -58,144 +50,153 @@ class OIDCConfig(Config): @property def oidc_enabled(self) -> bool: # OIDC is enabled if we have a provider - return bool(self.oidc_provider) + return bool(self.oidc_providers) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ - # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. + # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration + # and login. + # + # Options for each entry include: + # + # idp_id: a unique identifier for this identity provider. Used internally + # by Synapse; should be a single word such as 'github'. + # + # Note that, if this is changed, users authenticating via that provider + # will no longer be recognised as the same user! + # + # idp_name: A user-facing name for this identity provider, which is used to + # offer the user a choice of login mechanisms. + # + # discover: set to 'false' to disable the use of the OIDC discovery mechanism + # to discover endpoints. Defaults to true. + # + # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery + # is enabled) to discover the provider's endpoints. + # + # client_id: Required. oauth2 client id to use. + # + # client_secret: Required. oauth2 client secret to use. + # + # client_auth_method: auth method to use when exchanging the token. Valid + # values are 'client_secret_basic' (default), 'client_secret_post' and + # 'none'. + # + # scopes: list of scopes to request. This should normally include the "openid" + # scope. Defaults to ["openid"]. + # + # authorization_endpoint: the oauth2 authorization endpoint. Required if + # provider discovery is disabled. + # + # token_endpoint: the oauth2 token endpoint. Required if provider discovery is + # disabled. + # + # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is + # disabled and the 'openid' scope is not requested. + # + # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and + # the 'openid' scope is used. + # + # skip_verification: set to 'true' to skip metadata verification. Use this if + # you are connecting to a provider that is not OpenID Connect compliant. + # Defaults to false. Avoid this in production. + # + # user_profile_method: Whether to fetch the user profile from the userinfo + # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. + # + # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is + # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the + # userinfo endpoint. + # + # allow_existing_users: set to 'true' to allow a user logging in via OIDC to + # match a pre-existing account instead of failing. This could be used if + # switching from password logins to OIDC. Defaults to false. 
+ # + # user_mapping_provider: Configuration for how attributes returned from a OIDC + # provider are mapped onto a matrix user. This setting has the following + # sub-properties: + # + # module: The class name of a custom mapping module. Default is + # {mapping_provider!r}. + # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers + # for information on implementing a custom mapping provider. + # + # config: Configuration for the mapping provider module. This section will + # be passed as a Python dictionary to the user mapping provider + # module's `parse_config` method. + # + # For the default provider, the following settings are available: + # + # sub: name of the claim containing a unique identifier for the + # user. Defaults to 'sub', which OpenID Connect compliant + # providers should provide. + # + # localpart_template: Jinja2 template for the localpart of the MXID. + # If this is not set, the user will be prompted to choose their + # own username. + # + # display_name_template: Jinja2 template for the display name to set + # on first login. If unset, no displayname will be set. + # + # extra_attributes: a map of Jinja2 templates for extra attributes + # to send back to the client during login. + # Note that these are non-standard and clients will ignore them + # without modifications. + # + # When rendering, the Jinja2 templates are given a 'user' variable, + # which is set to the claims returned by the UserInfo Endpoint and/or + # in the ID Token. # # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md - # for some example configurations. + # for information on how to configure these options. # - oidc_config: - # Uncomment the following to enable authorization against an OpenID Connect - # server. Defaults to false. - # - #enabled: true - - # Uncomment the following to disable use of the OIDC discovery mechanism to - # discover endpoints. Defaults to true. - # - #discover: false - - # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to - # discover the provider's endpoints. - # - # Required if 'enabled' is true. - # - #issuer: "https://accounts.example.com/" - - # oauth2 client id to use. - # - # Required if 'enabled' is true. - # - #client_id: "provided-by-your-issuer" - - # oauth2 client secret to use. - # - # Required if 'enabled' is true. - # - #client_secret: "provided-by-your-issuer" - - # auth method to use when exchanging the token. - # Valid values are 'client_secret_basic' (default), 'client_secret_post' and - # 'none'. - # - #client_auth_method: client_secret_post - - # list of scopes to request. This should normally include the "openid" scope. - # Defaults to ["openid"]. - # - #scopes: ["openid", "profile"] - - # the oauth2 authorization endpoint. Required if provider discovery is disabled. - # - #authorization_endpoint: "https://accounts.example.com/oauth2/auth" - - # the oauth2 token endpoint. Required if provider discovery is disabled. - # - #token_endpoint: "https://accounts.example.com/oauth2/token" - - # the OIDC userinfo endpoint. Required if discovery is disabled and the - # "openid" scope is not requested. - # - #userinfo_endpoint: "https://accounts.example.com/userinfo" - - # URI where to fetch the JWKS. Required if discovery is disabled and the - # "openid" scope is used. - # - #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" - - # Uncomment to skip metadata verification. Defaults to false. 
- # - # Use this if you are connecting to a provider that is not OpenID Connect - # compliant. - # Avoid this in production. - # - #skip_verification: true - - # Whether to fetch the user profile from the userinfo endpoint. Valid - # values are: "auto" or "userinfo_endpoint". + # For backwards compatibility, it is also possible to configure a single OIDC + # provider via an 'oidc_config' setting. This is now deprecated and admins are + # advised to migrate to the 'oidc_providers' format. + # + oidc_providers: + # Generic example # - # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included - # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. + #- idp_id: my_idp + # idp_name: "My OpenID provider" + # discover: false + # issuer: "https://accounts.example.com/" + # client_id: "provided-by-your-issuer" + # client_secret: "provided-by-your-issuer" + # client_auth_method: client_secret_post + # scopes: ["openid", "profile"] + # authorization_endpoint: "https://accounts.example.com/oauth2/auth" + # token_endpoint: "https://accounts.example.com/oauth2/token" + # userinfo_endpoint: "https://accounts.example.com/userinfo" + # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + # skip_verification: true + + # For use with Keycloak # - #user_profile_method: "userinfo_endpoint" - - # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead - # of failing. This could be used if switching from password logins to OIDC. Defaults to false. - # - #allow_existing_users: true - - # An external module can be provided here as a custom solution to mapping - # attributes returned from a OIDC provider onto a matrix user. + #- idp_id: keycloak + # idp_name: Keycloak + # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" + # client_id: "synapse" + # client_secret: "copy secret generated in Keycloak UI" + # scopes: ["openid", "profile"] + + # For use with Github # - user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # Default is {mapping_provider!r}. - # - # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers - # for information on implementing a custom mapping provider. - # - #module: mapping_provider.OidcMappingProvider - - # Custom configuration values for the module. This section will be passed as - # a Python dictionary to the user mapping provider module's `parse_config` - # method. - # - # The examples below are intended for the default provider: they should be - # changed if using a custom provider. - # - config: - # name of the claim containing a unique identifier for the user. - # Defaults to `sub`, which OpenID Connect compliant providers should provide. - # - #subject_claim: "sub" - - # Jinja2 template for the localpart of the MXID. - # - # When rendering, this template is given the following variables: - # * user: The claims returned by the UserInfo Endpoint and/or in the ID - # Token - # - # If this is not set, the user will be prompted to choose their - # own username. - # - #localpart_template: "{{{{ user.preferred_username }}}}" - - # Jinja2 template for the display name to set on first login. - # - # If unset, no displayname will be set. - # - #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" - - # Jinja2 templates for extra attributes to send back to the client during - # login. - # - # Note that these are non-standard and clients will ignore them without modifications. 
-        #
-        #extra_attributes:
-        #  birthdate: "{{{{ user.birthdate }}}}"
+        #- idp_id: github
+        #  idp_name: Github
+        #  discover: false
+        #  issuer: "https://github.com/"
+        #  client_id: "your-client-id" # TO BE FILLED
+        #  client_secret: "your-client-secret" # TO BE FILLED
+        #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+        #  token_endpoint: "https://github.com/login/oauth/access_token"
+        #  userinfo_endpoint: "https://api.github.com/user"
+        #  scopes: ["read:user"]
+        #  user_mapping_provider:
+        #    config:
+        #      subject_claim: "id"
+        #      localpart_template: "{{{{ user.login }}}}"
+        #      display_name_template: "{{{{ user.name }}}}"
         """.format(
             mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
         )
@@ -234,7 +235,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
     },
 }
 
-# the `oidc_config` setting can either be None (as it is in the default
+# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
+OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
+    "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
+}
+
+
+# the `oidc_providers` list can either be None (as it is in the default config), or
+# a list of provider configs, each of which requires an explicit ID and name.
+OIDC_PROVIDER_LIST_SCHEMA = {
+    "oneOf": [
+        {"type": "null"},
+        {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
+    ]
+}
+
+# the `oidc_config` setting can either be None (which it used to be in the default
 # config), or an object. If an object, it is ignored unless it has an "enabled: True"
 # property.
 #
@@ -243,12 +259,41 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
 # additional checks in the code.
 OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
 
+# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
 MAIN_CONFIG_SCHEMA = {
     "type": "object",
-    "properties": {"oidc_config": OIDC_CONFIG_SCHEMA},
+    "properties": {
+        "oidc_config": OIDC_CONFIG_SCHEMA,
+        "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
+    },
 }
 
 
+def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
+    """extract and parse the OIDC provider configs from the config dict
+
+    The configuration may contain either a single `oidc_config` object with an
+    `enabled: True` property, or a list of provider configurations under
+    `oidc_providers`, *or both*.
+
+    Returns a generator which yields the OidcProviderConfig objects
+    """
+    validate_config(MAIN_CONFIG_SCHEMA, config, ())
+
+    for p in config.get("oidc_providers") or []:
+        yield _parse_oidc_config_dict(p)
+
+    # for backwards-compatibility, it is also possible to provide a single "oidc_config"
+    # object with an "enabled: True" property.
+    oidc_config = config.get("oidc_config")
+    if oidc_config and oidc_config.get("enabled", False):
+        # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
+        # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
+        # above), so now we need to validate it.
+        validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
+        yield _parse_oidc_config_dict(oidc_config)
+
+
 def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
     """Take the configuration dict and parse it into an OidcProviderConfig
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index f63a90ec5c..5e5fda7b2f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -78,21 +78,28 @@ class OidcHandler:
     def __init__(self, hs: "HomeServer"):
         self._sso_handler = hs.get_sso_handler()
 
-        provider_conf = hs.config.oidc.oidc_provider
+        provider_confs = hs.config.oidc.oidc_providers
         # we should not have been instantiated if there is no configured provider.
-        assert provider_conf is not None
+        assert provider_confs
 
         self._token_generator = OidcSessionTokenGenerator(hs)
-
-        self._provider = OidcProvider(hs, self._token_generator, provider_conf)
+        self._providers = {
+            p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+        }
 
     async def load_metadata(self) -> None:
        """Validate the config and load the metadata from the remote endpoint.
 
        Called at startup to ensure we have everything we need.
        """
-        await self._provider.load_metadata()
-        await self._provider.load_jwks()
+        for idp_id, p in self._providers.items():
+            try:
+                await p.load_metadata()
+                await p.load_jwks()
+            except Exception as e:
+                raise Exception(
+                    "Error while initialising OIDC provider %r" % (idp_id,)
+                ) from e
 
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
@@ -184,6 +191,12 @@ class OidcHandler:
             self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
+        oidc_provider = self._providers.get(session_data.idp_id)
+        if not oidc_provider:
+            logger.error("OIDC session uses unknown IdP %r", session_data.idp_id)
+            self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+            return
+
         if b"code" not in request.args:
             logger.info("Code parameter is missing")
             self._sso_handler.render_error(
@@ -193,7 +206,7 @@ class OidcHandler:
 
         code = request.args[b"code"][0].decode()
 
-        await self._provider.handle_oidc_callback(request, session_data, code)
+        await oidc_provider.handle_oidc_callback(request, session_data, code)
 
 
 class OidcError(Exception):
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 02e21ed6ca..b3dfa40d25 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -145,7 +145,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
 
         self.handler = hs.get_oidc_handler()
-        self.provider = self.handler._provider
+        self.provider = self.handler._providers["oidc"]
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
self.render_error = Mock(return_value=None) @@ -866,7 +866,7 @@ async def _make_callback_with_userinfo( from synapse.handlers.oidc_handler import OidcSessionData handler = hs.get_oidc_handler() - provider = handler._provider + provider = handler._providers["oidc"] provider._exchange_code = simple_async_mock(return_value={}) provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) -- cgit 1.5.1 From 350d9923cd1d35885f8f8e9c6036caec5eebfa9f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 15 Jan 2021 17:18:37 +0000 Subject: Make chain cover index bg update go faster (#9124) We do this by allowing a single iteration to process multiple rooms at a time, as there are often a lot of really tiny rooms, which can massively slow things down. --- changelog.d/9124.misc | 1 + .../storage/databases/main/events_bg_updates.py | 329 +++++++++++---------- tests/storage/test_event_chain.py | 217 ++++++++++++-- 3 files changed, 366 insertions(+), 181 deletions(-) create mode 100644 changelog.d/9124.misc (limited to 'tests') diff --git a/changelog.d/9124.misc b/changelog.d/9124.misc new file mode 100644 index 0000000000..346741d982 --- /dev/null +++ b/changelog.d/9124.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 7128dc1742..e46e44ba54 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -16,6 +16,8 @@ import logging from typing import Dict, List, Optional, Tuple +import attr + from synapse.api.constants import EventContentFields from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -28,6 +30,25 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True) +class _CalculateChainCover: + """Return value for _calculate_chain_cover_txn. + """ + + # The last room_id/depth/stream processed. + room_id = attr.ib(type=str) + depth = attr.ib(type=int) + stream = attr.ib(type=int) + + # Number of rows processed + processed_count = attr.ib(type=int) + + # Map from room_id to last depth/stream processed for each room that we have + # processed all events for (i.e. the rooms we can flip the + # `has_auth_chain_index` for) + finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]]) + + class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" @@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): current_room_id = progress.get("current_room_id", "") - # Have we finished processing the current room. - finished = progress.get("finished", True) - # Where we've processed up to in the room, defaults to the start of the # room. last_depth = progress.get("last_depth", -1) last_stream = progress.get("last_stream", -1) - # Have we set the `has_auth_chain_index` for the room yet. - has_set_room_has_chain_index = progress.get( - "has_set_room_has_chain_index", False + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + current_room_id, + last_depth, + last_stream, + batch_size, + single_room=False, ) - if finished: - # If we've finished with the previous room (or its our first - # iteration) we move on to the next room. 
- - def _get_next_room(txn: Cursor) -> Optional[str]: - sql = """ - SELECT room_id FROM rooms - WHERE room_id > ? - AND ( - NOT has_auth_chain_index - OR has_auth_chain_index IS NULL - ) - ORDER BY room_id - LIMIT 1 - """ - txn.execute(sql, (current_room_id,)) - row = txn.fetchone() - if row: - return row[0] + finished = result.processed_count == 0 - return None - - current_room_id = await self.db_pool.runInteraction( - "_chain_cover_index", _get_next_room - ) - if not current_room_id: - await self.db_pool.updates._end_background_update("chain_cover") - return 0 - - logger.debug("Adding chain cover to %s", current_room_id) - - def _calculate_auth_chain( - txn: Cursor, last_depth: int, last_stream: int - ) -> Tuple[int, int, int]: - # Get the next set of events in the room (that we haven't already - # computed chain cover for). We do this in topological order. - - # We want to do a `(topological_ordering, stream_ordering) > (?,?)` - # comparison, but that is not supported on older SQLite versions - tuple_clause, tuple_args = make_tuple_comparison_clause( - self.database_engine, - [ - ("topological_ordering", last_depth), - ("stream_ordering", last_stream), - ], - ) + total_rows_processed = result.processed_count + current_room_id = result.room_id + last_depth = result.depth + last_stream = result.stream - sql = """ - SELECT - event_id, state_events.type, state_events.state_key, - topological_ordering, stream_ordering - FROM events - INNER JOIN state_events USING (event_id) - LEFT JOIN event_auth_chains USING (event_id) - LEFT JOIN event_auth_chain_to_calculate USING (event_id) - WHERE events.room_id = ? - AND event_auth_chains.event_id IS NULL - AND event_auth_chain_to_calculate.event_id IS NULL - AND %(tuple_cmp)s - ORDER BY topological_ordering, stream_ordering - LIMIT ? - """ % { - "tuple_cmp": tuple_clause, - } - - args = [current_room_id] - args.extend(tuple_args) - args.append(batch_size) - - txn.execute(sql, args) - rows = txn.fetchall() - - # Put the results in the necessary format for - # `_add_chain_cover_index` - event_to_room_id = {row[0]: current_room_id for row in rows} - event_to_types = {row[0]: (row[1], row[2]) for row in rows} - - new_last_depth = rows[-1][3] if rows else last_depth # type: int - new_last_stream = rows[-1][4] if rows else last_stream # type: int - - count = len(rows) - - # We also need to fetch the auth events for them. - auth_events = self.db_pool.simple_select_many_txn( - txn, - table="event_auth", - column="event_id", - iterable=event_to_room_id, - keyvalues={}, - retcols=("event_id", "auth_id"), - ) - - event_to_auth_chain = {} # type: Dict[str, List[str]] - for row in auth_events: - event_to_auth_chain.setdefault(row["event_id"], []).append( - row["auth_id"] - ) - - # Calculate and persist the chain cover index for this set of events. - # - # Annoyingly we need to gut wrench into the persit event store so that - # we can reuse the function to calculate the chain cover for rooms. 
- PersistEventsStore._add_chain_cover_index( - txn, - self.db_pool, - event_to_room_id, - event_to_types, - event_to_auth_chain, - ) - - return new_last_depth, new_last_stream, count - - last_depth, last_stream, count = await self.db_pool.runInteraction( - "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream - ) - - total_rows_processed = count - - if count < batch_size and not has_set_room_has_chain_index: + for room_id, (depth, stream) in result.finished_room_map.items(): # If we've done all the events in the room we flip the # `has_auth_chain_index` in the DB. Note that its possible for # further events to be persisted between the above and setting the @@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): await self.db_pool.simple_update( table="rooms", - keyvalues={"room_id": current_room_id}, + keyvalues={"room_id": room_id}, updatevalues={"has_auth_chain_index": True}, desc="_chain_cover_index", ) - has_set_room_has_chain_index = True # Handle any events that might have raced with us flipping the # bit above. - last_depth, last_stream, count = await self.db_pool.runInteraction( - "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + room_id, + depth, + stream, + batch_size=None, + single_room=True, ) - total_rows_processed += count + total_rows_processed += result.processed_count - # Note that at this point its technically possible that more events - # than our `batch_size` have been persisted without their chain - # cover, so we need to continue processing this room if the last - # count returned was equal to the `batch_size`. + if finished: + await self.db_pool.updates._end_background_update("chain_cover") + return total_rows_processed - if count < batch_size: - # We've finished calculating the index for this room, move on to the - # next room. - await self.db_pool.updates._background_update_progress( - "chain_cover", {"current_room_id": current_room_id, "finished": True}, - ) - else: - # We still have outstanding events to calculate the index for. - await self.db_pool.updates._background_update_progress( - "chain_cover", - { - "current_room_id": current_room_id, - "last_depth": last_depth, - "last_stream": last_stream, - "has_auth_chain_index": has_set_room_has_chain_index, - "finished": False, - }, - ) + await self.db_pool.updates._background_update_progress( + "chain_cover", + { + "current_room_id": current_room_id, + "last_depth": last_depth, + "last_stream": last_stream, + }, + ) return total_rows_processed + + def _calculate_chain_cover_txn( + self, + txn: Cursor, + last_room_id: str, + last_depth: int, + last_stream: int, + batch_size: Optional[int], + single_room: bool, + ) -> _CalculateChainCover: + """Calculate the chain cover for `batch_size` events, ordered by + `(room_id, depth, stream)`. + + Args: + txn, + last_room_id, last_depth, last_stream: The `(room_id, depth, stream)` + tuple to fetch results after. + batch_size: The maximum number of events to process. If None then + no limit. + single_room: Whether to calculate the index for just the given + room. + """ + + # Get the next set of events in the room (that we haven't already + # computed chain cover for). We do this in topological order. 
+
+        # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+        # comparison, but that is not supported on older SQLite versions
+        tuple_clause, tuple_args = make_tuple_comparison_clause(
+            self.database_engine,
+            [
+                ("events.room_id", last_room_id),
+                ("topological_ordering", last_depth),
+                ("stream_ordering", last_stream),
+            ],
+        )
+
+        extra_clause = ""
+        if single_room:
+            extra_clause = "AND events.room_id = ?"
+            tuple_args.append(last_room_id)
+
+        sql = """
+            SELECT
+                event_id, state_events.type, state_events.state_key,
+                topological_ordering, stream_ordering,
+                events.room_id
+            FROM events
+            INNER JOIN state_events USING (event_id)
+            LEFT JOIN event_auth_chains USING (event_id)
+            LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+            WHERE event_auth_chains.event_id IS NULL
+                AND event_auth_chain_to_calculate.event_id IS NULL
+                AND %(tuple_cmp)s
+                %(extra)s
+            ORDER BY events.room_id, topological_ordering, stream_ordering
+            %(limit)s
+        """ % {
+            "tuple_cmp": tuple_clause,
+            "limit": "LIMIT ?" if batch_size is not None else "",
+            "extra": extra_clause,
+        }
+
+        if batch_size is not None:
+            tuple_args.append(batch_size)
+
+        txn.execute(sql, tuple_args)
+        rows = txn.fetchall()
+
+        # Put the results in the necessary format for
+        # `_add_chain_cover_index`
+        event_to_room_id = {row[0]: row[5] for row in rows}
+        event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+        # Calculate the new last position we've processed up to.
+        new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+        new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+        new_last_room_id = rows[-1][5] if rows else ""  # type: str
+
+        # Map from room_id to last depth/stream_ordering processed for the room,
+        # excluding the last room (which we're likely still processing). We also
+        # need to include the room passed in if it's not included in the result
+        # set (as we then know we've processed all events in said room).
+        #
+        # This is the set of rooms that we can now safely flip the
+        # `has_auth_chain_index` bit for.
+        finished_rooms = {
+            row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+        }
+        if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+            finished_rooms[last_room_id] = (last_depth, last_stream)
+
+        count = len(rows)
+
+        # We also need to fetch the auth events for them.
+        auth_events = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth",
+            column="event_id",
+            iterable=event_to_room_id,
+            keyvalues={},
+            retcols=("event_id", "auth_id"),
+        )
+
+        event_to_auth_chain = {}  # type: Dict[str, List[str]]
+        for row in auth_events:
+            event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+        # Calculate and persist the chain cover index for this set of events.
+        #
+        # Annoyingly we need to gut wrench into the persist event store so that
+        # we can reuse the function to calculate the chain cover for rooms.
+ PersistEventsStore._add_chain_cover_index( + txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + ) + + return _CalculateChainCover( + room_id=new_last_room_id, + depth=new_last_depth, + stream=new_last_stream, + processed_count=count, + finished_room_map=finished_rooms, + ) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index ff67a73749..0c46ad595b 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple from twisted.trial import unittest @@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): login.register_servlets, ] - def test_background_update(self): - """Test that the background update to calculate auth chains for historic - rooms works correctly. - """ - - # Create a room - user_id = self.register_user("foo", "pass") - token = self.login("foo", "pass") - room_id = self.helper.create_room_as(user_id, tok=token) - requester = create_requester(user_id) + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + self.requester = create_requester(self.user_id) - store = self.hs.get_datastore() + def _generate_room(self) -> Tuple[str, List[Set[str]]]: + """Insert a room without a chain cover index. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Mark the room as not having a chain cover index self.get_success( - store.db_pool.simple_update( + self.store.db_pool.simple_update( table="rooms", keyvalues={"room_id": room_id}, updatevalues={"has_auth_chain_index": False}, @@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Create a fork in the DAG with different events. event_handler = self.hs.get_event_creation_handler() - latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + latest_event_ids = self.get_success( + self.store.get_prev_events_for_room(room_id) + ) event, context = self.get_success( event_handler.create_event( - requester, + self.requester, { "type": "some_state_type", "state_key": "", "content": {}, "room_id": room_id, - "sender": user_id, + "sender": self.user_id, }, prev_event_ids=latest_event_ids, ) ) self.get_success( - event_handler.handle_new_client_event(requester, event, context) + event_handler.handle_new_client_event(self.requester, event, context) ) - state1 = list(self.get_success(context.get_current_state_ids()).values()) + state1 = set(self.get_success(context.get_current_state_ids()).values()) event, context = self.get_success( event_handler.create_event( - requester, + self.requester, { "type": "some_state_type", "state_key": "", "content": {}, "room_id": room_id, - "sender": user_id, + "sender": self.user_id, }, prev_event_ids=latest_event_ids, ) ) self.get_success( - event_handler.handle_new_client_event(requester, event, context) + event_handler.handle_new_client_event(self.requester, event, context) ) - state2 = list(self.get_success(context.get_current_state_ids()).values()) + state2 = set(self.get_success(context.get_current_state_ids()).values()) # Delete the chain cover info. 
@@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 txn.execute("DELETE FROM event_auth_chains")
                 txn.execute("DELETE FROM event_auth_chain_links")
 
-        self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+        self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
+
+        return room_id, [state1, state2]
+
+    def test_background_update_single_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
 
         # Insert and run the background update.
         self.get_success(
-            store.db_pool.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {"update_name": "chain_cover", "progress_json": "{}"},
             )
         )
 
         # Ugh, have to reset this flag
-        store.db_pool.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         while not self.get_success(
-            store.db_pool.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                store.db_pool.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # Test that the `has_auth_chain_index` has been set
-        self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
 
         # Test that calculating the auth chain difference using the newly
         # calculated chain cover works.
         self.get_success(
-            store.db_pool.runInteraction(
+            self.store.db_pool.runInteraction(
                 "test",
-                store._get_auth_chain_difference_using_cover_index_txn,
+                self.store._get_auth_chain_difference_using_cover_index_txn,
                 room_id,
-                [state1, state2],
+                states,
             )
         )
+
+    def test_background_update_multiple_rooms(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+        # Create the rooms
+        room_id1, states1 = self._generate_room()
+        room_id2, states2 = self._generate_room()
+        room_id3, states3 = self._generate_room()
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id1,
+                states1,
+            )
+        )
+
+    def test_background_update_single_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+ for i in range(0, 150): + self.helper.send_state( + room_id, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id, + states, + ) + ) + + def test_background_update_multiple_large_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create the rooms + room_id1, _ = self._generate_room() + room_id2, _ = self._generate_room() + + # Add a bunch of state so that it takes multiple iterations of the + # background update to process the room. + for i in range(0, 150): + self.helper.send_state( + room_id1, event_type="m.test", body={"index": i}, tok=self.token + ) + + for i in range(0, 150): + self.helper.send_state( + room_id2, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2))) -- cgit 1.5.1 From b5dea8702d9b799a15d5b8fb90e82497b8c17f63 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jan 2021 18:03:33 +0000 Subject: Fix test failure due to bad merge 0dd2649c1 (#9112) changed the signature of `auth_via_oidc`. Meanwhile, 26d10331e (#9091) introduced a new test which relied on the old signature of `auth_via_oidc`. The two branches were never tested together until they landed in develop. 
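A side note on the chain-cover batching above: `make_tuple_comparison_clause` exists because
older SQLite versions cannot evaluate a row-value comparison such as
`(room_id, topological_ordering, stream_ordering) > (?, ?, ?)` directly, so it has to be
expanded into plain comparisons. A standalone sketch of that expansion (hedged: the real
helper's exact output and parameter handling may differ):

    def tuple_gt_clause(columns):
        # Expands (a, b, c) > (?, ?, ?) into
        # ((a > ?) OR (a = ? AND b > ?) OR (a = ? AND b = ? AND c > ?));
        # the caller binds parameters in the matching order: a; a, b; a, b, c.
        parts = []
        for i in range(len(columns)):
            comparisons = ["%s = ?" % c for c in columns[:i]]
            comparisons.append("%s > ?" % columns[i])
            parts.append("(%s)" % " AND ".join(comparisons))
        return "(%s)" % " OR ".join(parts)

    print(tuple_gt_clause(["events.room_id", "topological_ordering", "stream_ordering"]))
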
--- tests/rest/client/v2_alpha/test_auth.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 3e8661f9b9..a6488a3d29 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -475,7 +475,9 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc("wrong_user", ui_auth_session_id=session_id) + channel = self.helper.auth_via_oidc( + {"sub": "wrong_user"}, ui_auth_session_id=session_id + ) # that should return a failure message self.assertSubstring("We were unable to validate", channel.text_body) -- cgit 1.5.1 From 02070c69faa47bf6aef280939c2d5f32cbcb9f25 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 18 Jan 2021 14:52:49 +0000 Subject: Fix bugs in handling clientRedirectUrl, and improve OIDC tests (#9127, #9128) * Factor out a common TestHtmlParser Looks like I'm doing this in a few different places. * Improve OIDC login test Complete the OIDC login flow, rather than giving up halfway through. * Ensure that OIDC login works with multiple OIDC providers * Fix bugs in handling clientRedirectUrl - don't drop duplicate query-params, or params with no value - allow utf-8 in query-params --- changelog.d/9127.feature | 1 + changelog.d/9128.bugfix | 1 + synapse/handlers/auth.py | 4 +- synapse/handlers/oidc_handler.py | 2 +- synapse/rest/synapse/client/pick_idp.py | 4 +- tests/rest/client/v1/test_login.py | 146 ++++++++++++++++++++------------ tests/rest/client/v1/utils.py | 62 ++++++++------ tests/server.py | 2 +- tests/test_utils/html_parsers.py | 53 ++++++++++++ 9 files changed, 189 insertions(+), 86 deletions(-) create mode 100644 changelog.d/9127.feature create mode 100644 changelog.d/9128.bugfix create mode 100644 tests/test_utils/html_parsers.py (limited to 'tests') diff --git a/changelog.d/9127.feature b/changelog.d/9127.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9127.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9128.bugfix b/changelog.d/9128.bugfix new file mode 100644 index 0000000000..f87b9fb9aa --- /dev/null +++ b/changelog.d/9128.bugfix @@ -0,0 +1 @@ +Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. 
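The heart of the `clientRedirectUrl` fix is visible in the `auth.py` change below; the two
failure modes are easy to reproduce with the standard library alone (the query string here
is illustrative):

    import urllib.parse

    query = "a=1&a=2&flag&b="

    # the old dict-based approach collapses the repeated 'a', and parse_qsl
    # drops the valueless 'flag' unless told otherwise
    assert dict(urllib.parse.parse_qsl(query)) == {"a": "2"}

    # keeping the raw (name, value) pairs preserves duplicates and blanks
    assert urllib.parse.parse_qsl(query, keep_blank_values=True) == [
        ("a", "1"),
        ("a", "2"),
        ("flag", ""),
        ("b", ""),
    ]
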
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 18cd2b62f0..0e98db22b3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1504,8 +1504,8 @@ class AuthHandler(BaseHandler): @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({param_name: param}) + query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True) + query.append((param_name, param)) url_parts[4] = urllib.parse.urlencode(query) return urllib.parse.urlunparse(url_parts) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 5e5fda7b2f..ba686d74b2 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -85,7 +85,7 @@ class OidcHandler: self._token_generator = OidcSessionTokenGenerator(hs) self._providers = { p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs - } + } # type: Dict[str, OidcProvider] async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py index e5b720bbca..9550b82998 100644 --- a/synapse/rest/synapse/client/pick_idp.py +++ b/synapse/rest/synapse/client/pick_idp.py @@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource): self._server_name = hs.hostname async def _async_render_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl", required=True) + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding="utf-8" + ) idp = parse_string(request, "idp", required=False) # if we need to pick an IdP, do so diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 73a009efd1..2d25490374 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,9 +15,8 @@ import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from urllib.parse import parse_qs, urlencode, urlparse +from typing import Any, Dict, Union +from urllib.parse import urlencode from mock import Mock @@ -38,6 +37,7 @@ from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless try: @@ -69,6 +69,12 @@ TEST_SAML_METADATA = """ LOGIN_URL = b"/_matrix/client/r0/login" TEST_URL = b"/_matrix/client/r0/account/whoami" +# a (valid) url with some annoying characters in. 
%3D is =, %26 is &, %2B is +
+TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+# the query params in TEST_CLIENT_REDIRECT_URL
+EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -389,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             },
         }
 
+        # default OIDC provider
         config["oidc_config"] = TEST_OIDC_CONFIG
 
+        # additional OIDC providers
+        config["oidc_providers"] = [
+            {
+                "idp_id": "idp1",
+                "idp_name": "IDP1",
+                "discover": False,
+                "issuer": "https://issuer1",
+                "client_id": "test-client-id",
+                "client_secret": "test-client-secret",
+                "scopes": ["profile"],
+                "authorization_endpoint": "https://issuer1/auth",
+                "token_endpoint": "https://issuer1/token",
+                "userinfo_endpoint": "https://issuer1/userinfo",
+                "user_mapping_provider": {
+                    "config": {"localpart_template": "{{ user.sub }}"}
+                },
+            }
+        ]
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
+        from synapse.rest.oidc import OIDCResource
+
         d = super().create_resource_dict()
         d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
+        d["/_synapse/oidc"] = OIDCResource(self.hs)
         return d
 
     def test_multi_sso_redirect(self):
         """/login/sso/redirect should redirect to an identity picker"""
-        client_redirect_url = "https://x?"
-
         # first hit the redirect url, which should redirect to our idp picker
         channel = self.make_request(
             "GET",
-            "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
+            "/_matrix/client/r0/login/sso/redirect?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
         )
         self.assertEqual(channel.code, 302, channel.result)
         uri = channel.headers.getRawHeaders("Location")[0]
@@ -415,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
-        class FormPageParser(HTMLParser):
-            def __init__(self):
-                super().__init__()
-
-                # the values of the hidden inputs: map from name to value
-                self.hiddens = {}  # type: Dict[str, Optional[str]]
-
-                # the values of the radio buttons
-                self.radios = []  # type: List[Optional[str]]
-
-            def handle_starttag(
-                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
-            ) -> None:
-                attr_dict = dict(attrs)
-                if tag == "input":
-                    if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
-                        self.radios.append(attr_dict["value"])
-                    elif attr_dict["type"] == "hidden":
-                        input_name = attr_dict["name"]
-                        assert input_name
-                        self.hiddens[input_name] = attr_dict["value"]
-
-            def error(_, message):
-                self.fail(message)
-
-        p = FormPageParser()
+        p = TestHtmlParser()
         p.feed(channel.result["body"].decode("utf-8"))
         p.close()
 
-        self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
+        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"])
 
-        self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
+        self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
 
     def test_multi_sso_redirect_to_cas(self):
         """If CAS is chosen, should redirect to the CAS server"""
-        client_redirect_url = "https://x?"
channel = self.make_request( "GET", - "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=cas", shorthand=False, ) self.assertEqual(channel.code, 302, channel.result) @@ -470,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri = cas_uri_params["service"][0] _, service_uri_query = service_uri.split("?", 1) service_uri_params = urllib.parse.parse_qs(service_uri_query) - self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) + self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_saml(self): """If SAML is chosen, should redirect to the SAML server""" - client_redirect_url = "https://x?" - channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) self.assertEqual(channel.code, 302, channel.result) @@ -492,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # the RelayState is used to carry the client redirect url saml_uri_params = urllib.parse.parse_qs(saml_uri_query) relay_state_param = saml_uri_params["RelayState"][0] - self.assertEqual(relay_state_param, client_redirect_url) + self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_oidc(self): + def test_login_via_oidc(self): """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - client_redirect_url = "https://x?" + # pick the default OIDC provider channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=oidc", ) self.assertEqual(channel.code, 302, channel.result) @@ -521,9 +522,41 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) self.assertEqual( self._get_value_from_macaroon(macaroon, "client_redirect_url"), - client_redirect_url, + TEST_CLIENT_REDIRECT_URL, ) + channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + self.assertTrue( + channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html") + ) + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + + # ... which should contain our redirect link + self.assertEqual(len(p.links), 1) + path, query = p.links[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. 
+ login_token = params[2][1] + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@user1:test") + def test_multi_sso_redirect_to_unknown(self): """An unknown IdP should cause a 400""" channel = self.make_request( @@ -1082,7 +1115,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # whitelist this client URI so we redirect straight to it rather than # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + config["sso"] = {"client_whitelist": ["https://x"]} return config def create_resource_dict(self) -> Dict[str, Resource]: @@ -1095,11 +1128,10 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self): """Test the happy path of a username picker flow.""" - client_redirect_url = "https://whitelisted.client" # do the start of the login flow channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, client_redirect_url + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL ) # that should redirect to the username picker @@ -1122,7 +1154,7 @@ class UsernamePickerTestCase(HomeserverTestCase): session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, client_redirect_url) + self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) # the expiry time should be about 15 minutes away expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) @@ -1146,15 +1178,19 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") - # ensure that the returned location starts with the requested redirect URL - self.assertEqual( - location_headers[0][: len(client_redirect_url)], client_redirect_url + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") # fish the login token out of the returned redirect uri - parts = urlparse(location_headers[0]) - query = parse_qs(parts.query) - login_token = query["loginToken"][0] + login_token = params[2][1] # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. 
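The final assertions above can be reproduced standalone; the redirect location below is
illustrative, but the encoding matches what `urlencode` produces for
`TEST_CLIENT_REDIRECT_URL`'s params plus an appended login token:

    import urllib.parse

    location = (
        "https://x?%3Cab+c%3E=&q%22+%3D%2B%22=%22f%C3%B6%26%3Do%22&loginToken=abc123"
    )
    path, query = location.split("?", 1)
    assert path == "https://x"

    params = urllib.parse.parse_qsl(query, keep_blank_values=True, strict_parsing=True)

    # the client's original query params survive the round-trip intact...
    assert params[0] == ("<ab c>", "")
    assert params[1] == ('q" =+"', '"fö&=o"')
    # ...and the login token is appended as the final param
    assert params[2] == ("loginToken", "abc123")
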
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index c6647dbe08..b1333df82d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -20,8 +20,7 @@ import json
 import re
 import time
 import urllib.parse
-from html.parser import HTMLParser
-from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple
+from typing import Any, Dict, Mapping, MutableMapping, Optional
 
 from mock import patch
 
@@ -35,6 +34,7 @@ from synapse.types import JsonDict
 
 from tests.server import FakeChannel, FakeSite, make_request
 from tests.test_utils import FakeResponse
+from tests.test_utils.html_parsers import TestHtmlParser
 
 
 @attr.s
@@ -440,10 +440,36 @@ class RestHelper:
         # param that synapse passes to the IdP via query params, as well as the cookie
         # that synapse passes to the client.
 
-        oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1)
+        oauth_uri_path, _ = oauth_uri.split("?", 1)
         assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
             "unexpected SSO URI " + oauth_uri_path
         )
+        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+    def complete_oidc_auth(
+        self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+    ) -> FakeChannel:
+        """Mock out an OIDC authentication flow
+
+        Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+        initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+        Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+        sent back to the OIDC provider.
+
+        Requires the OIDC callback resource to be mounted at the normal place.
+
+        Args:
+            oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+                from initiate_sso_login or initiate_sso_ui_auth).
+            cookies: the cookies set by synapse's redirect endpoint, which will be
+                sent back to the callback endpoint.
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+        """
+        _, oauth_uri_qs = oauth_uri.split("?", 1)
         params = urllib.parse.parse_qs(oauth_uri_qs)
         callback_uri = "%s?%s" % (
             urllib.parse.urlparse(params["redirect_uri"][0]).path,
@@ -456,9 +482,9 @@ class RestHelper:
         expected_requests = [
             # first we get a hit to the token endpoint, which we tell to return
             # a dummy OIDC access token
-            ("https://issuer.test/token", {"access_token": "TEST"}),
+            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
             # and then one to the user_info endpoint, which returns our remote user id.
-            ("https://issuer.test/userinfo", user_info_dict),
+            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
         ]
 
         async def mock_req(method: str, uri: str, data=None, headers=None):
@@ -542,25 +568,7 @@ class RestHelper:
         channel.extract_cookies(cookies)
 
         # parse the confirmation page to fish out the link.
-        class ConfirmationPageParser(HTMLParser):
-            def __init__(self):
-                super().__init__()
-
-                self.links = []  # type: List[str]
-
-            def handle_starttag(
-                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
-            ) -> None:
-                attr_dict = dict(attrs)
-                if tag == "a":
-                    href = attr_dict["href"]
-                    if href:
-                        self.links.append(href)
-
-            def error(_, message):
-                raise AssertionError(message)
-
-        p = ConfirmationPageParser()
+        p = TestHtmlParser()
         p.feed(channel.text_body)
         p.close()
         assert len(p.links) == 1, "not exactly one link in confirmation page"
@@ -570,6 +578,8 @@ class RestHelper:
 
 # an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" TEST_OIDC_CONFIG = { "enabled": True, "discover": False, @@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = { "client_secret": "test-client-secret", "scopes": ["profile"], "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": "https://issuer.test/token", - "userinfo_endpoint": "https://issuer.test/userinfo", + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, } diff --git a/tests/server.py b/tests/server.py index 5a1b66270f..5a85d5fe7f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -74,7 +74,7 @@ class FakeChannel: return int(self.result["code"]) @property - def headers(self): + def headers(self) -> Headers: if not self.result: raise Exception("No result yet.") h = Headers() diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py new file mode 100644 index 0000000000..ad563eb3f0 --- /dev/null +++ b/tests/test_utils/html_parsers.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from html.parser import HTMLParser
+from typing import Dict, Iterable, List, Optional, Tuple
+
+
+class TestHtmlParser(HTMLParser):
+    """A generic HTML page parser which extracts useful things from the HTML"""
+
+    def __init__(self):
+        super().__init__()
+
+        # a list of links found in the doc
+        self.links = []  # type: List[str]
+
+        # the values of any hidden <input>s: map from name to value
+        self.hiddens = {}  # type: Dict[str, Optional[str]]
+
+        # the values of any radio buttons: map from name to list of values
+        self.radios = {}  # type: Dict[str, List[Optional[str]]]
+
+    def handle_starttag(
+        self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+    ) -> None:
+        attr_dict = dict(attrs)
+        if tag == "a":
+            href = attr_dict["href"]
+            if href:
+                self.links.append(href)
+        elif tag == "input":
+            input_name = attr_dict.get("name")
+            if attr_dict["type"] == "radio":
+                assert input_name
+                self.radios.setdefault(input_name, []).append(attr_dict["value"])
+            elif attr_dict["type"] == "hidden":
+                assert input_name
+                self.hiddens[input_name] = attr_dict["value"]
+
+    def error(_, message):
+        raise AssertionError(message)
-- cgit 1.5.1


From 6633a4015a7b4ba60f87c5e6f979a9c9d8f9d8fe Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 18 Jan 2021 15:47:59 +0000
Subject: Allow moving account data and receipts streams off master (#9104)

---
 changelog.d/9104.feature                           |   1 +
 synapse/app/generic_worker.py                      |  15 +-
 synapse/config/workers.py                          |  18 +-
 synapse/handlers/account_data.py                   | 144 ++++++++++++++++
 synapse/handlers/read_marker.py                    |   5 +-
 synapse/handlers/receipts.py                       |  27 ++-
 synapse/handlers/room_member.py                    |   7 +-
 synapse/replication/http/__init__.py               |   2 +
 synapse/replication/http/account_data.py           | 187 +++++++++++++++++++++
 synapse/replication/slave/storage/_base.py         |  10 +-
 synapse/replication/slave/storage/account_data.py  |  40 +----
 synapse/replication/slave/storage/receipts.py      |  35 +---
 synapse/replication/tcp/handler.py                 |  19 +++
 synapse/rest/client/v2_alpha/account_data.py       |  22 +--
 synapse/rest/client/v2_alpha/tags.py               |  11 +-
 synapse/server.py                                  |   5 +
 synapse/storage/databases/main/__init__.py         |  10 +-
 synapse/storage/databases/main/account_data.py     | 107 +++++++++---
 synapse/storage/databases/main/deviceinbox.py      |   4 +-
 .../storage/databases/main/event_push_actions.py   |  92 +++++-----
 synapse/storage/databases/main/events_worker.py    |   8 +-
 synapse/storage/databases/main/receipts.py         | 108 ++++++++----
 .../main/schema/delta/59/06shard_account_data.sql  |  20 +++
 .../delta/59/06shard_account_data.sql.postgres     |  32 ++++
 synapse/storage/databases/main/tags.py             |  10 +-
 synapse/storage/util/id_generators.py              |  84 +++++----
 tests/storage/test_id_generators.py                | 112 +++++++++++-
 27 files changed, 855 insertions(+), 280 deletions(-)
 create mode 100644 changelog.d/9104.feature
 create mode 100644 synapse/replication/http/account_data.py
 create mode 100644 synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
 create mode 100644 synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
(limited to 'tests')

diff --git a/changelog.d/9104.feature b/changelog.d/9104.feature
new file mode 100644
index 0000000000..1c4f88bce9
--- /dev/null
+++ b/changelog.d/9104.feature
@@ -0,0 +1 @@
+Add experimental support for moving receipts and account data persistence off master.
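Operationally, the commit below lets a deployment pin these streams to a single worker. A
hedged homeserver.yaml sketch (the `instance_map`/`stream_writers` keys follow Synapse's
existing worker configuration; the worker name, host and port are placeholders):

    instance_map:
      account_worker:
        host: localhost
        port: 9034    # the worker's replication listener

    stream_writers:
      account_data: account_worker
      receipts: account_worker

Note that the checks added to `synapse/config/workers.py` below insist on exactly one
writer instance for each of these streams.
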
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index cb202bda44..e60988fa4a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -100,7 +100,16 @@ from synapse.rest.client.v1.profile import ( ) from synapse.rest.client.v1.push_rule import PushRuleRestServlet from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory +from synapse.rest.client.v2_alpha import ( + account_data, + groups, + read_marker, + receipts, + room_keys, + sync, + tags, + user_directory, +) from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha.account import ThreepidRestServlet from synapse.rest.client.v2_alpha.account_data import ( @@ -531,6 +540,10 @@ class GenericWorkerServer(HomeServer): room.register_deprecated_servlets(self, resource) InitialSyncRestServlet(self).register(resource) room_keys.register_servlets(self, resource) + tags.register_servlets(self, resource) + account_data.register_servlets(self, resource) + receipts.register_servlets(self, resource) + read_marker.register_servlets(self, resource) SendToDeviceRestServlet(self).register(resource) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 364583f48b..f10e33f7b8 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -56,6 +56,12 @@ class WriterLocations: to_device = attr.ib( default=["master"], type=List[str], converter=_instance_to_list_converter, ) + account_data = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) + receipts = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) class WorkerConfig(Config): @@ -127,7 +133,7 @@ class WorkerConfig(Config): # Check that the configured writers for events and typing also appears in # `instance_map`. - for stream in ("events", "typing", "to_device"): + for stream in ("events", "typing", "to_device", "account_data", "receipts"): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: if instance != "master" and instance not in self.instance_map: @@ -141,6 +147,16 @@ class WorkerConfig(Config): "Must only specify one instance to handle `to_device` messages." ) + if len(self.writers.account_data) != 1: + raise ConfigError( + "Must only specify one instance to handle `account_data` messages." + ) + + if len(self.writers.receipts) != 1: + raise ConfigError( + "Must only specify one instance to handle `receipts` messages." + ) + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) # Whether this worker should run background tasks or not. diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 341135822e..b1a5df9638 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +13,157 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import random from typing import TYPE_CHECKING, List, Tuple +from synapse.replication.http.account_data import ( + ReplicationAddTagRestServlet, + ReplicationRemoveTagRestServlet, + ReplicationRoomAccountDataRestServlet, + ReplicationUserAccountDataRestServlet, +) from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.app.homeserver import HomeServer +class AccountDataHandler: + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + self._instance_name = hs.get_instance_name() + self._notifier = hs.get_notifier() + + self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) + self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) + self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) + self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) + self._account_data_writers = hs.config.worker.writers.account_data + + async def add_account_data_to_room( + self, user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> int: + """Add some account_data to a room for a user. + + Args: + user_id: The user to add a tag for. + room_id: The room to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + + Returns: + The maximum stream ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_account_data_to_room( + user_id, room_id, account_data_type, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + + return max_stream_id + else: + response = await self._room_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + account_data_type=account_data_type, + content=content, + ) + return response["max_stream_id"] + + async def add_account_data_for_user( + self, user_id: str, account_data_type: str, content: JsonDict + ) -> int: + """Add some account_data to a room for a user. + + Args: + user_id: The user to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + + Returns: + The maximum stream ID. + """ + + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_account_data_for_user( + user_id, account_data_type, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._user_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + account_data_type=account_data_type, + content=content, + ) + return response["max_stream_id"] + + async def add_tag_to_room( + self, user_id: str, room_id: str, tag: str, content: JsonDict + ) -> int: + """Add a tag to a room for a user. + + Args: + user_id: The user to add a tag for. + room_id: The room to add a tag for. + tag: The tag name to add. + content: A json object to associate with the tag. + + Returns: + The next account data ID. 
+ """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_tag_to_room( + user_id, room_id, tag, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._add_tag_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + tag=tag, + content=content, + ) + return response["max_stream_id"] + + async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: + """Remove a tag from a room for a user. + + Returns: + The next account data ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_tag_from_room( + user_id, room_id, tag + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._remove_tag_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + tag=tag, + ) + return response["max_stream_id"] + + class AccountDataEventSource: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index a7550806e6..6bb2fd936b 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler): super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() + self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") - self.notifier = hs.get_notifier() async def received_client_read_marker( self, room_id: str, user_id: str, event_id: str @@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler): if should_update: content = {"event_id": event_id} - max_id = await self.store.add_account_data_to_room( + await self.account_data_handler.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a9abdf42e0..cc21fc2284 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler): self.server_name = hs.config.server_name self.store = hs.get_datastore() self.hs = hs - self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler( - "m.receipt", self._received_remote_receipt - ) + + # We only need to poke the federation sender explicitly if its on the + # same instance. Other federation sender instances will get notified by + # `synapse.app.generic_worker.FederationSenderHandler` when it sees it + # in the receipts stream. + self.federation_sender = None + if hs.should_send_federation(): + self.federation_sender = hs.get_federation_sender() + + # If we can handle the receipt EDUs we do so, otherwise we route them + # to the appropriate worker. 
+ if hs.get_instance_name() in hs.config.worker.writers.receipts: + hs.get_federation_registry().register_edu_handler( + "m.receipt", self._received_remote_receipt + ) + else: + hs.get_federation_registry().register_instances_for_edu( + "m.receipt", hs.config.worker.writers.receipts, + ) + self.clock = self.hs.get_clock() self.state = hs.get_state_handler() @@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler): if not is_new: return - await self.federation.send_read_receipt(receipt) + if self.federation_sender: + await self.federation_sender.send_read_receipt(receipt) class ReceiptEventSource: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cb5a29bc7e..e001e418f9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.registration_handler = hs.get_registration_handler() self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.account_data_handler = hs.get_account_data_handler() self.member_linearizer = Linearizer(name="member") @@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): direct_rooms[key].append(new_room_id) # Save back to user's m.direct account data - await self.store.add_account_data_for_user( + await self.account_data_handler.add_account_data_for_user( user_id, AccountDataTypes.DIRECT, direct_rooms ) break @@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) + await self.account_data_handler.add_tag_to_room( + user_id, new_room_id, tag, tag_content + ) async def update_membership( self, diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index a84a064c8d..dd527e807f 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -15,6 +15,7 @@ from synapse.http.server import JsonResource from synapse.replication.http import ( + account_data, devices, federation, login, @@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource): presence.register_servlets(hs, self) membership.register_servlets(hs, self) streams.register_servlets(hs, self) + account_data.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py new file mode 100644 index 0000000000..52d32528ee --- /dev/null +++ b/synapse/replication/http/account_data.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
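For orientation, the servlets defined in this new file are not called directly: the handler above obtains an async callable for each via `make_client`, and invoking that callable issues a POST to the chosen writer instance. Roughly, as a hedged sketch (the function and its `writer_instance` parameter are hypothetical; `make_client` and the response shape come from the patch itself):

```python
# Illustrative only: how the handler-side replication clients map onto the
# servlets defined below. `hs` is a HomeServer; the "add_tag" path segment
# comes from the servlet's NAME attribute.
from synapse.replication.http.account_data import ReplicationAddTagRestServlet


async def add_tag_via_writer(hs, writer_instance, user_id, room_id, tag, content):
    client = ReplicationAddTagRestServlet.make_client(hs)
    # Issues POST /_synapse/replication/add_tag/:user_id/:room_id/:tag to
    # `writer_instance`; anything not captured in PATH_ARGS (here, the tag
    # content) travels in the JSON body built by _serialize_payload.
    response = await client(
        instance_name=writer_instance,
        user_id=user_id,
        room_id=room_id,
        tag=tag,
        content=content,
    )
    return response["max_stream_id"]
```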
+ +import logging + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): + """Add user account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_user_account_data/:user_id/:type + + { + "content": { ... }, + } + + """ + + NAME = "add_user_account_data" + PATH_ARGS = ("user_id", "account_data_type") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, account_data_type, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, account_data_type): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_account_data_for_user( + user_id, account_data_type, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): + """Add room account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type + + { + "content": { ... }, + } + + """ + + NAME = "add_room_account_data" + PATH_ARGS = ("user_id", "room_id", "account_data_type") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, account_data_type, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, room_id, account_data_type): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_account_data_to_room( + user_id, room_id, account_data_type, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationAddTagRestServlet(ReplicationEndpoint): + """Add tag on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_tag/:user_id/:room_id/:tag + + { + "content": { ... }, + } + + """ + + NAME = "add_tag" + PATH_ARGS = ("user_id", "room_id", "tag") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, tag, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, room_id, tag): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_tag_to_room( + user_id, room_id, tag, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationRemoveTagRestServlet(ReplicationEndpoint): + """Remove tag on the appropriate account data worker. 
+ + Request format: + + POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag + + {} + + """ + + NAME = "remove_tag" + PATH_ARGS = ( + "user_id", + "room_id", + "tag", + ) + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, tag): + + return {} + + async def _handle_request(self, request, user_id, room_id, tag): + max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,) + + return 200, {"max_stream_id": max_stream_id} + + +def register_servlets(hs, http_server): + ReplicationUserAccountDataRestServlet(hs).register(http_server) + ReplicationRoomAccountDataRestServlet(hs).register(http_server) + ReplicationAddTagRestServlet(hs).register(http_server) + ReplicationRemoveTagRestServlet(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index d0089fe06c..693c9ab901 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): database, stream_name="caches", instance_name=hs.get_instance_name(), - table="cache_invalidation_stream_by_instance", - instance_column="instance_name", - id_column="stream_id", + tables=[ + ( + "cache_invalidation_stream_by_instance", + "instance_name", + "stream_id", + ) + ], sequence_name="cache_invalidation_stream_seq", writers=[], ) # type: Optional[MultiWriterIdGenerator] diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 4268565fc8..21afe5f155 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -15,47 +15,9 @@ # limitations under the License. 
from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "account_data", - "stream_id", - extra_tables=[ - ("room_account_data", "stream_id"), - ("room_tags_revisions", "stream_id"), - ], - ) - - super().__init__(database, db_conn, hs) - - def get_max_account_data_stream_id(self): - return self._account_data_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - for row in rows: - if not row.room_id: - self.get_global_account_data_by_type_for_user.invalidate( - (row.data_type, row.user_id) - ) - self.get_account_data_for_user.invalidate((row.user_id,)) - self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) - self.get_account_data_for_room_and_type.invalidate( - (row.user_id, row.room_id, row.data_type) - ) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 6195917376..3dfdd9961d 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -14,43 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.replication.tcp.streams import ReceiptsStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) - - super().__init__(database, db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() - - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): - self.get_receipts_for_user.invalidate((user_id, receipt_type)) - self._get_linearized_receipts_for_room.invalidate_many((room_id,)) - self.get_last_receipt_event_id_for_user.invalidate( - (user_id, room_id, receipt_type) - ) - self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) - self.get_receipts_for_room.invalidate((room_id, receipt_type)) - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == ReceiptsStream.NAME: - self._receipts_id_gen.advance(instance_name, token) - for row in rows: - self.invalidate_caches_for_receipt( - row.room_id, row.receipt_type, row.user_id - ) - self._receipts_stream_cache.entity_has_changed(row.room_id, token) - - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1f89249475..317796d5e0 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import ( from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, + AccountDataStream, BackfillStream, CachesStream, EventsStream, FederationStream, + ReceiptsStream, Stream, + TagAccountDataStream, ToDeviceStream, TypingStream, ) @@ -132,6 +135,22 @@ class ReplicationCommandHandler: continue + if isinstance(stream, (AccountDataStream, TagAccountDataStream)): + # Only add AccountDataStream and TagAccountDataStream as a source on the + # instance in charge of account_data persistence. + if hs.get_instance_name() in hs.config.worker.writers.account_data: + self._streams_to_replicate.append(stream) + + continue + + if isinstance(stream, ReceiptsStream): + # Only add ReceiptsStream as a source on the instance in charge of + # receipts. + if hs.get_instance_name() in hs.config.worker.writers.receipts: + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. 
if hs.config.worker_app is not None: continue diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 87a5b1b86b..3f28c0bc3e 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") body = parse_json_object_from_request(request) - max_id = await self.store.add_account_data_for_user( - user_id, account_data_type, body - ) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_account_data_for_user(user_id, account_data_type, body) return 200, {} @@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = await self.store.add_account_data_to_room( + await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, body ) - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - return 200, {} async def on_GET(self, request, user_id, room_id, account_data_type): diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index bf3a79db44..a97cd66c52 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -58,8 +58,7 @@ class TagServlet(RestServlet): def __init__(self, hs): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, tag): requester = await self.auth.get_user_by_req(request) @@ -68,9 +67,7 @@ class TagServlet(RestServlet): body = parse_json_object_from_request(request) - max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_tag_to_room(user_id, room_id, tag, body) return 200, {} @@ -79,9 +76,7 @@ class TagServlet(RestServlet): if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.remove_tag_from_room(user_id, room_id, tag) return 200, {} diff --git a/synapse/server.py b/synapse/server.py index d4c235cda5..9cdda83aa1 100644 --- 
a/synapse/server.py +++ b/synapse/server.py @@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler +from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.acme import AcmeHandler from synapse.handlers.admin import AdminHandler @@ -711,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_module_api(self) -> ModuleApi: return ModuleApi(self, self.get_auth_handler()) + @cache_in_self + def get_account_data_handler(self) -> AccountDataHandler: + return AccountDataHandler(self) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index c4de07a0a8..ae561a2da3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -160,9 +160,13 @@ class DataStore( database, stream_name="caches", instance_name=hs.get_instance_name(), - table="cache_invalidation_stream_by_instance", - instance_column="instance_name", - id_column="stream_id", + tables=[ + ( + "cache_invalidation_stream_by_instance", + "instance_name", + "stream_id", + ) + ], sequence_name="cache_invalidation_stream_seq", writers=[], ) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index bad8260892..68896f34af 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,14 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import Dict, List, Optional, Set, Tuple from synapse.api.constants import AccountDataTypes +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. -class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): +class AccountDataWorkerStore(SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. 
""" def __init__(self, database: DatabasePool, db_conn, hs): + self._instance_name = hs.get_instance_name() + + if isinstance(database.engine, PostgresEngine): + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) + + self._account_data_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="account_data", + instance_name=self._instance_name, + tables=[ + ("room_account_data", "instance_name", "stream_id"), + ("room_tags_revisions", "instance_name", "stream_id"), + ("account_data", "instance_name", "stream_id"), + ], + sequence_name="account_data_sequence", + writers=hs.config.worker.writers.account_data, + ) + else: + self._can_write_to_account_data = True + + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + else: + self._account_data_id_gen = SlavedIdTracker( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max @@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): super().__init__(database, db_conn, hs) - @abc.abstractmethod - def get_max_account_data_stream_id(self): + def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream Returns: int """ - raise NotImplementedError() + return self._account_data_id_gen.get_current_token() @cached() async def get_account_data_for_user( @@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): ) ) - -class AccountDataStore(AccountDataWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - - super().__init__(database, db_conn, hs) - - def get_max_account_data_stream_id(self) -> int: - """Get the current max stream id for the private user data stream - - Returns: - The maximum stream ID. 
- """ - return self._account_data_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + if not row.room_id: + self.get_global_account_data_by_type_for_user.invalidate( + (row.data_type, row.user_id) + ) + self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) + self.get_account_data_for_room_and_type.invalidate( + (row.user_id, row.room_id, row.data_type) + ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore): Returns: The maximum stream ID. """ + assert self._can_write_to_account_data + content_json = json_encoder.encode(content) async with self._account_data_id_gen.get_next() as next_id: @@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore): Returns: The maximum stream ID. """ + assert self._can_write_to_account_data + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "add_user_account_data", @@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore): # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) + + +class AccountDataStore(AccountDataWorkerStore): + pass diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 58d3f71e45..31f70ac5ef 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): db=database, stream_name="to_device", instance_name=self._instance_name, - table="device_inbox", - instance_column="instance_name", - id_column="stream_id", + tables=[("device_inbox", "instance_name", "stream_id")], sequence_name="device_inbox_sequence", writers=hs.config.worker.writers.to_device, ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index e5c03cc609..1b657191a9 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore): (rotate_to_stream_ordering,), ) + def _remove_old_push_actions_before_txn( + self, txn, room_id, user_id, stream_ordering + ): + """ + Purges old push actions for a user and room before a given + stream_ordering. + + We however keep a months worth of highlighted notifications, so that + users can still get a list of recent highlights. + + Args: + txn: The transcation + room_id: Room ID to delete from + user_id: user ID to delete for + stream_ordering: The lowest stream ordering which will + not be deleted. 
+ """ + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id, user_id), + ) + + # We need to join on the events table to get the received_ts for + # event_push_actions and sqlite won't let us use a join in a delete so + # we can't just delete where received_ts < x. Furthermore we can + # only identify event_push_actions by a tuple of room_id, event_id + # we we can't use a subquery. + # Instead, we look up the stream ordering for the last event in that + # room received before the threshold time and delete event_push_actions + # in the room with a stream_odering before that. + txn.execute( + "DELETE FROM event_push_actions " + " WHERE user_id = ? AND room_id = ? AND " + " stream_ordering <= ?" + " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), + ) + + txn.execute( + """ + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, + (room_id, user_id, stream_ordering), + ) + class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" @@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): - """ - Purges old push actions for a user and room before a given - stream_ordering. - - We however keep a months worth of highlighted notifications, so that - users can still get a list of recent highlights. - - Args: - txn: The transcation - room_id: Room ID to delete from - user_id: user ID to delete for - stream_ordering: The lowest stream ordering which will - not be deleted. - """ - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id, user_id), - ) - - # We need to join on the events table to get the received_ts for - # event_push_actions and sqlite won't let us use a join in a delete so - # we can't just delete where received_ts < x. Furthermore we can - # only identify event_push_actions by a tuple of room_id, event_id - # we we can't use a subquery. - # Instead, we look up the stream ordering for the last event in that - # room received before the threshold time and delete event_push_actions - # in the room with a stream_odering before that. - txn.execute( - "DELETE FROM event_push_actions " - " WHERE user_id = ? AND room_id = ? AND " - " stream_ordering <= ?" - " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), - ) - - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? 
- """, - (room_id, user_id, stream_ordering), - ) - def _action_has_highlight(actions): for action in actions: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4732685f6e..71d823be72 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, stream_name="events", instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", + tables=[("events", "instance_name", "stream_ordering")], sequence_name="events_stream_seq", writers=hs.config.worker.writers.events, ) @@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, stream_name="backfill", instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", + tables=[("events", "instance_name", "stream_ordering")], sequence_name="events_backfill_stream_seq", positive=False, writers=hs.config.worker.writers.events, diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 1e7949a323..e0e57f0578 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,15 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import Any, Dict, List, Optional, Tuple from twisted.internet import defer +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. -class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): - """This is an abstract base class where subclasses must implement - `get_max_receipt_stream_id` which can be called in the initializer. - """ - +class ReceiptsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): + self._instance_name = hs.get_instance_name() + + if isinstance(database.engine, PostgresEngine): + self._can_write_to_receipts = ( + self._instance_name in hs.config.worker.writers.receipts + ) + + self._receipts_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="account_data", + instance_name=self._instance_name, + tables=[("receipts_linearized", "instance_name", "stream_id")], + sequence_name="receipts_sequence", + writers=hs.config.worker.writers.receipts, + ) + else: + self._can_write_to_receipts = True + + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. 
(Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + else: + self._receipts_id_gen = SlavedIdTracker( + db_conn, "receipts_linearized", "stream_id" + ) + super().__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - @abc.abstractmethod def get_max_receipt_stream_id(self): """Get the current max stream ID for receipts stream Returns: int """ - raise NotImplementedError() + return self._receipts_id_gen.get_current_token() @cached() async def get_users_with_read_receipts_in_room(self, room_id): @@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - -class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" + def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + self.get_receipts_for_user.invalidate((user_id, receipt_type)) + self._get_linearized_receipts_for_room.invalidate_many((room_id,)) + self.get_last_receipt_event_id_for_user.invalidate( + (user_id, room_id, receipt_type) ) + self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) + self.get_receipts_for_room.invalidate((room_id, receipt_type)) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + for row in rows: + self.invalidate_caches_for_receipt( + row.room_id, row.receipt_type, row.user_id + ) + self._receipts_stream_cache.entity_has_changed(row.room_id, token) - super().__init__(database, db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() + return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( self, txn, room_id, receipt_type, user_id, event_id, data, stream_id @@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ + assert self._can_write_to_receipts + res = self.db_pool.simple_select_one_txn( txn, table="events", @@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore): ) return None - txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) - txn.call_after( - self._invalidate_get_users_with_receipts_in_room, - room_id, - receipt_type, - user_id, - ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) - # FIXME: This shouldn't invalidate the whole cache txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) + self.invalidate_caches_for_receipt, room_id, receipt_type, user_id ) txn.call_after( self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) - txn.call_after( - self.get_last_receipt_event_id_for_user.invalidate, - (user_id, room_id, receipt_type), - ) - self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", @@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore): Automatically does conversion between linearized and graph 
representations. """ + assert self._can_write_to_receipts + if not event_ids: return None @@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore): async def insert_graph_receipt( self, room_id, receipt_type, user_id, event_ids, data ): + assert self._can_write_to_receipts + return await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, @@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore): def insert_graph_receipt_txn( self, txn, room_id, receipt_type, user_id, event_ids, data ): + assert self._can_write_to_receipts + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after( self._invalidate_get_users_with_receipts_in_room, @@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "data": json_encoder.encode(data), }, ) + + +class ReceiptsStore(ReceiptsWorkerStore): + pass diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql new file mode 100644 index 0000000000..46abf8d562 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_account_data ADD COLUMN instance_name TEXT; +ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT; +ALTER TABLE account_data ADD COLUMN instance_name TEXT; + +ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres new file mode 100644 index 0000000000..4a6e6c74f5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres @@ -0,0 +1,32 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE SEQUENCE IF NOT EXISTS account_data_sequence; + +-- We need to take the max across all the account_data tables as they share the +-- ID generator +SELECT setval('account_data_sequence', ( + SELECT GREATEST( + (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data), + (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions), + (SELECT COALESCE(MAX(stream_id), 1) FROM account_data) + ) +)); + +CREATE SEQUENCE IF NOT EXISTS receipts_sequence; + +SELECT setval('receipts_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized +)); diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 74da9c49f2..50067eabfc 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore): ) return {row["tag"]: db_to_json(row["content"]) for row in rows} - -class TagsStore(TagsWorkerStore): async def add_tag_to_room( self, user_id: str, room_id: str, tag: str, content: JsonDict ) -> int: @@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ + assert self._can_write_to_account_data + content_json = json_encoder.encode(content) def add_tag_txn(txn, next_id): @@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ + assert self._can_write_to_account_data def remove_tag_txn(txn, next_id): sql = ( @@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore): room_id: The ID of the room. next_id: The the revision to advance to. """ + assert self._can_write_to_account_data txn.call_after( self._account_data_stream_cache.entity_has_changed, user_id, next_id @@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore): # which stream_id ends up in the table, as long as it is higher # than the id that the client has. pass + + +class TagsStore(TagsWorkerStore): + pass diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 133c0e7a28..39a3ab1162 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -17,7 +17,7 @@ import logging import threading from collections import deque from contextlib import contextmanager -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Tuple, Union import attr from typing_extensions import Deque @@ -186,11 +186,12 @@ class MultiWriterIdGenerator: Args: db_conn db - stream_name: A name for the stream. + stream_name: A name for the stream, for use in the `stream_positions` + table. (Does not need to be the same as the replication stream name) instance_name: The name of this instance. - table: Database table associated with stream. - instance_column: Column that stores the row's writer's instance name - id_column: Column that stores the stream ID. + tables: List of tables associated with the stream. Tuple of table + name, column name that stores the writer's instance name, and + column name that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. 
writers: A list of known writers to use to populate current positions @@ -206,9 +207,7 @@ class MultiWriterIdGenerator: db: DatabasePool, stream_name: str, instance_name: str, - table: str, - instance_column: str, - id_column: str, + tables: List[Tuple[str, str, str]], sequence_name: str, writers: List[str], positive: bool = True, @@ -260,15 +259,16 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) # We check that the table and sequence haven't diverged. - self._sequence_gen.check_consistency( - db_conn, table=table, id_column=id_column, positive=positive - ) + for table, _, id_column in tables: + self._sequence_gen.check_consistency( + db_conn, table=table, id_column=id_column, positive=positive + ) # This goes and fills out the above state from the database. - self._load_current_ids(db_conn, table, instance_column, id_column) + self._load_current_ids(db_conn, tables) def _load_current_ids( - self, db_conn, table: str, instance_column: str, id_column: str + self, db_conn, tables: List[Tuple[str, str, str]], ): cur = db_conn.cursor(txn_name="_load_current_ids") @@ -306,17 +306,22 @@ class MultiWriterIdGenerator: # We add a GREATEST here to ensure that the result is always # positive. (This can be a problem for e.g. backfill streams where # the server has never backfilled). - sql = """ - SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) - FROM %(table)s - """ % { - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - cur.execute(sql) - (stream_id,) = cur.fetchone() - self._persisted_upto_position = stream_id + max_stream_id = 1 + for table, _, id_column in tables: + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + + max_stream_id = max(max_stream_id, stream_id) + + self._persisted_upto_position = max_stream_id else: # If we have a min_stream_id then we pull out everything greater # than it from the DB so that we can prefill @@ -329,21 +334,28 @@ class MultiWriterIdGenerator: # stream positions table before restart (or the stream position # table otherwise got out of date). - sql = """ - SELECT %(instance)s, %(id)s FROM %(table)s - WHERE ? %(cmp)s %(id)s - """ % { - "id": id_column, - "table": table, - "instance": instance_column, - "cmp": "<=" if self._positive else ">=", - } - cur.execute(sql, (min_stream_id * self._return_factor,)) - self._persisted_upto_position = min_stream_id + rows = [] + for table, instance_column, id_column in tables: + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + cur.execute(sql, (min_stream_id * self._return_factor,)) + + rows.extend(cur) + + # Sort so that we handle rows in order for each instance. 
+ rows.sort() + with self._lock: - for (instance, stream_id,) in cur: + for (instance, stream_id,) in rows: stream_id = self._return_factor * stream_id self._add_persisted_position(stream_id) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index cc0612cf65..3e2fd4da01 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.db_pool, stream_name="test_stream", instance_name=instance_name, - table="foobar", - instance_column="instance_name", - id_column="stream_id", + tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", writers=writers, ) @@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.db_pool, stream_name="test_stream", instance_name=instance_name, - table="foobar", - instance_column="instance_name", - id_column="stream_id", + tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", writers=writers, positive=False, @@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) + + +class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db_pool = self.store.db_pool # type: DatabasePool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar1 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + txn.execute( + """ + CREATE TABLE foobar2 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db_pool, + stream_name="test_stream", + instance_name=instance_name, + tables=[ + ("foobar1", "instance_name", "stream_id"), + ("foobar2", "instance_name", "stream_id"), + ], + sequence_name="foobar_seq", + writers=writers, + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def _insert_rows( + self, + table: str, + instance_name: str, + number: int, + update_stream_table: bool = True, + ): + """Insert N rows as the given instance, inserting with stream IDs pulled + from the postgres sequence. + """ + + def _insert(txn): + for _ in range(number): + txn.execute( + "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), + (instance_name,), + ) + if update_stream_table: + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) + + self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) + + def test_load_existing_stream(self): + """Test creating ID gens with multiple tables that have rows from after + the position in `stream_positions` table. 
+ """ + self._insert_rows("foobar1", "first", 3) + self._insert_rows("foobar2", "second", 3) + self._insert_rows("foobar2", "second", 1, update_stream_table=False) + + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) + + # The first ID gen will notice that it can advance its token to 7 as it + # has no in progress writes... + self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) + + # ... but the second ID gen doesn't know that. + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) -- cgit 1.5.1 From fa50e4bf4ddcb8e98d44700513a28c490f80f02b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 12:30:41 +0000 Subject: Give `public_baseurl` a default value (#9159) --- changelog.d/9159.feature | 1 + docs/sample_config.yaml | 31 +++++++++++++++++-------------- synapse/api/urls.py | 2 -- synapse/config/_base.py | 11 ++++++----- synapse/config/emailconfig.py | 8 -------- synapse/config/oidc_config.py | 2 -- synapse/config/registration.py | 21 ++++----------------- synapse/config/saml2_config.py | 2 -- synapse/config/server.py | 24 +++++++++++++++--------- synapse/config/sso.py | 13 +++++-------- synapse/handlers/identity.py | 2 -- synapse/rest/well_known.py | 4 ---- tests/rest/test_well_known.py | 9 --------- tests/utils.py | 1 - 14 files changed, 48 insertions(+), 83 deletions(-) create mode 100644 changelog.d/9159.feature (limited to 'tests') diff --git a/changelog.d/9159.feature b/changelog.d/9159.feature new file mode 100644 index 0000000000..b7748757de --- /dev/null +++ b/changelog.d/9159.feature @@ -0,0 +1 @@ +Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ae995efe9b..7fdd798d70 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -67,11 +67,16 @@ pid_file: DATADIR/homeserver.pid # #web_client_location: https://riot.example.com/ -# The public-facing base URL that clients use to access this HS -# (not including _matrix/...). This is the same URL a user would -# enter into the 'custom HS URL' field on their client. If you -# use synapse with a reverse proxy, this should be the URL to reach -# synapse via the proxy. +# The public-facing base URL that clients use to access this Homeserver (not +# including _matrix/...). This is the same URL a user might enter into the +# 'Custom Homeserver URL' field on their client. If you use Synapse with a +# reverse proxy, this should be the URL to reach Synapse via the proxy. +# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see +# 'listeners' below). +# +# If this is left unset, it defaults to 'https:///'. (Note that +# that will not work unless you configure Synapse or a reverse-proxy to listen +# on port 443.) 
# #public_baseurl: https://example.com/ @@ -1150,8 +1155,9 @@ account_validity: # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -1242,8 +1248,7 @@ account_validity: # The identity server which we suggest that clients should use when users log # in on this server. # -# (By default, no suggestion is made, so it is left up to the client. -# This setting is ignored unless public_baseurl is also set.) +# (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -1268,8 +1273,6 @@ account_validity: # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # -# If a delegate is specified, the config option public_baseurl must also be filled out. -# account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process @@ -1901,9 +1904,9 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. # # By default, this list is empty. 
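The defaulting rule described above amounts to only a few lines. Here is a standalone sketch for illustration (`resolve_public_baseurl` is a hypothetical helper name; the real logic is the `ServerConfig` change in `synapse/config/server.py` further down this patch):

    def resolve_public_baseurl(config: dict, server_name: str) -> str:
        # Fall back to https://<server_name>/ when public_baseurl is unset.
        public_baseurl = config.get("public_baseurl") or "https://%s/" % (server_name,)
        # Normalise to a trailing slash so that "_matrix/..." paths can be appended.
        if public_baseurl[-1] != "/":
            public_baseurl += "/"
        return public_baseurl

    assert resolve_public_baseurl({}, "example.com") == "https://example.com/"
    assert resolve_public_baseurl({"public_baseurl": "https://mx.example.com"}, "example.com") == "https://mx.example.com/"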
# diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 6379c86dde..e36aeef31f 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,8 +42,6 @@ class ConsentURIBuilder: """ if hs_config.form_secret is None: raise ConfigError("form_secret not set in config") - if hs_config.public_baseurl is None: - raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 2931a88207..94144efc87 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -252,11 +252,12 @@ class Config: env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters - env.filters.update({"format_ts": _format_ts_filter}) - if self.public_baseurl: - env.filters.update( - {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)} - ) + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) for filename in filenames: # Load the template diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index d4328c46b9..6a487afd34 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -166,11 +166,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - # public_baseurl is required to build password reset and validation links that - # will be emailed to users - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) @@ -269,9 +264,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( "email.enable_notifs is True but required keys are missing: %s" diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 80a24cfbc9..df55367434 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -43,8 +43,6 @@ class OIDCConfig(Config): raise ConfigError(e.message) from e public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" @property diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 740c3fc1b1..4bfc69cb7a 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -49,10 +49,6 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 - if self.renew_by_email_enabled: - if "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") - template_dir = config.get("template_dir") if not template_dir: @@ -109,13 +105,6 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if self.account_threepid_delegate_msisdn and not self.public_baseurl: - raise ConfigError( - "The configuration option `public_baseurl` is required if " - "`account_threepid_delegate.msisdn` is set, such that " - "clients know where to submit validation tokens 
to. Please " - "configure `public_baseurl`." - ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -240,8 +229,9 @@ class RegistrationConfig(Config): # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -332,8 +322,7 @@ class RegistrationConfig(Config): # The identity server which we suggest that clients should use when users log # in on this server. # - # (By default, no suggestion is made, so it is left up to the client. - # This setting is ignored unless public_baseurl is also set.) + # (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -358,8 +347,6 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # - # If a delegate is specified, the config option public_baseurl must also be filled out. - # account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 7b97d4f114..f33dfa0d6a 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -189,8 +189,6 @@ class SAML2Config(Config): import saml2 public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("saml2_config requires a public_baseurl to be set") if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) diff --git a/synapse/config/server.py b/synapse/config/server.py index 7242a4aa8e..75ba161f35 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -161,7 +161,11 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) - self.public_baseurl = config.get("public_baseurl") + self.public_baseurl = config.get("public_baseurl") or "https://%s/" % ( + self.server_name, + ) + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -317,9 +321,6 @@ class ServerConfig(Config): # Always blacklist 0.0.0.0, :: self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - if self.public_baseurl is not None: - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -740,11 +741,16 @@ class ServerConfig(Config): # #web_client_location: https://riot.example.com/ - # The public-facing base URL that clients use to access this HS - # (not including _matrix/...). This is the same URL a user would - # enter into the 'custom HS URL' field on their client. If you - # use synapse with a reverse proxy, this should be the URL to reach - # synapse via the proxy. 
+ # The public-facing base URL that clients use to access this Homeserver (not + # including _matrix/...). This is the same URL a user might enter into the + # 'Custom Homeserver URL' field on their client. If you use Synapse with a + # reverse proxy, this should be the URL to reach Synapse via the proxy. + # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see + # 'listeners' below). + # + # If this is left unset, it defaults to 'https://<server_name>/'. (Note that + # that will not work unless you configure Synapse or a reverse-proxy to listen + # on port 443.) # #public_baseurl: https://example.com/ diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 366f0d4698..59be825532 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -64,11 +64,8 @@ class SSOConfig(Config): # gracefully to the client). This would make it pointless to ask the user for # confirmation, since the URL the confirmation page would be showing wouldn't be # the client's. - # public_baseurl is an optional setting, so we only add the fallback's URL to the - # list if it's provided (because we can't figure out what that URL is otherwise). - if self.public_baseurl: - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" - self.sso_client_whitelist.append(login_fallback_url) + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): return """\ @@ -86,9 +83,9 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. + # # By default, this list is empty. # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c05036ad1f..f61844d688 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - assert self.hs.config.public_baseurl - # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index f591cc6c5c..241fe746d9 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,10 +34,6 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self): - # if we don't have a public_baseurl, we can't help much here.
- if self._config.public_baseurl is None: - return None - result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 14de0921be..c5e44af9f7 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,12 +40,3 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) - - def test_well_known_no_public_baseurl(self): - self.hs.config.public_baseurl = None - - channel = self.make_request( - "GET", "/.well-known/matrix/client", shorthand=False - ) - - self.assertEqual(channel.code, 404) diff --git a/tests/utils.py b/tests/utils.py index 977eeaf6ee..09614093bc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -159,7 +159,6 @@ def default_config(name, parse=False): "remote": {"per_second": 10000, "burst_count": 10000}, }, "saml2_enabled": False, - "public_baseurl": None, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, "old_signing_keys": {}, -- cgit 1.5.1 From 0cd2938bc854d947ae8102ded688a626c9fac5b5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 13:15:14 +0000 Subject: Support icons for Identity Providers (#9154) --- changelog.d/9154.feature | 1 + docs/sample_config.yaml | 4 ++ mypy.ini | 1 + synapse/config/oidc_config.py | 20 ++++++ synapse/config/server.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/transport/server.py | 2 +- synapse/handlers/cas_handler.py | 4 ++ synapse/handlers/oidc_handler.py | 3 + synapse/handlers/room.py | 2 +- synapse/handlers/saml_handler.py | 4 ++ synapse/handlers/sso.py | 5 ++ synapse/http/endpoint.py | 79 --------------------- synapse/res/templates/sso_login_idp_picker.html | 3 + synapse/rest/client/v1/room.py | 3 +- synapse/storage/databases/main/room.py | 6 +- synapse/types.py | 2 +- synapse/util/stringutils.py | 92 +++++++++++++++++++++++++ tests/http/test_endpoint.py | 2 +- 19 files changed, 146 insertions(+), 91 deletions(-) create mode 100644 changelog.d/9154.feature delete mode 100644 synapse/http/endpoint.py (limited to 'tests') diff --git a/changelog.d/9154.feature b/changelog.d/9154.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9154.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7fdd798d70..b49a5da8cc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1726,6 +1726,10 @@ saml2_config: # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # +# idp_icon: An optional icon for this identity provider, which is presented +# by identity picker pages. If given, must be an MXC URI of the format +# mxc://<server-name>/<media-id> +# # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true.
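The `mxc://<server-name>/<media-id>` format mentioned above is validated by the `parse_and_validate_mxc_uri` helper that this commit adds to `synapse/util/stringutils.py` (see further down). A minimal standalone sketch of the parsing step, using the same regex as that helper (`split_mxc` is a hypothetical name):

    import re

    # Same pattern as the MXC_REGEX this commit adds to synapse/util/stringutils.py.
    MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")

    def split_mxc(mxc: str):
        # Split an mxc://<server-name>/<media-id> URI into its two parts.
        m = MXC_REGEX.match(mxc)
        if not m:
            raise ValueError("mxc URI %r did not match expected format" % (mxc,))
        return m.group(1), m.group(2)

    assert split_mxc("mxc://example.com/abc123") == ("example.com", "abc123")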
# diff --git a/mypy.ini b/mypy.ini index b996867121..bd99069c81 100644 --- a/mypy.ini +++ b/mypy.ini @@ -100,6 +100,7 @@ files = synapse/util/async_helpers.py, synapse/util/caches, synapse/util/metrics.py, + synapse/util/stringutils.py, tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index df55367434..f257fcd412 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -23,6 +23,7 @@ from synapse.config._util import validate_config from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module +from synapse.util.stringutils import parse_and_validate_mxc_uri from ._base import Config, ConfigError @@ -66,6 +67,10 @@ class OIDCConfig(Config): # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # + # idp_icon: An optional icon for this identity provider, which is presented + # by identity picker pages. If given, must be an MXC URI of the format + # mxc://<server-name>/<media-id> + # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # @@ -207,6 +212,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "properties": { "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, + "idp_icon": {"type": "string"}, "discover": {"type": "boolean"}, "issuer": {"type": "string"}, "client_id": {"type": "string"}, @@ -336,9 +342,20 @@ def _parse_oidc_config_dict( config_path + ("idp_id",), ) + # MSC2858 also specifies that the idp_icon must be a valid MXC uri + idp_icon = oidc_config.get("idp_icon") + if idp_icon is not None: + try: + parse_and_validate_mxc_uri(idp_icon) + except ValueError as e: + raise ConfigError( + "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) + ) from e + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), + idp_icon=idp_icon, discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -366,6 +383,9 @@ class OidcProviderConfig: # user-facing name for this identity provider. idp_name = attr.ib(type=str) + # Optional MXC URI for icon for this IdP.
+ idp_icon = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/config/server.py b/synapse/config/server.py index 75ba161f35..47a0370173 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -26,7 +26,7 @@ import yaml from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e5339aca23..171d25c945 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -49,7 +49,6 @@ from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction -from synapse.http.endpoint import parse_server_name from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_server_name if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cfd094e58f..95c64510a9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -28,7 +28,6 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -45,6 +44,7 @@ from synapse.logging.opentracing import ( ) from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index f3430c6713..0f342c607b 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,6 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" + # we do not currently support icons for CAS auth, but this is required by + # the SsoIdentityProvider protocol type. 
+ self.idp_icon = None + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index ba686d74b2..1607e12935 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -271,6 +271,9 @@ class OidcProvider: # user-facing name of this auth provider self.idp_name = provider.idp_name + # MXC URI for icon for this auth provider + self.idp_icon = provider.idp_icon + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3bece6d668..ee27d99135 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,7 +38,6 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents -from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -55,6 +54,7 @@ from synapse.types import ( from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client from ._base import BaseHandler diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a8376543c9..38461cf79d 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,6 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" + # we do not currently support icons for SAML auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index dcc85e9871..d493327a10 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -75,6 +75,11 @@ class SsoIdentityProvider(Protocol): def idp_name(self) -> str: """User-facing name for this provider""" + @property + def idp_icon(self) -> Optional[str]: + """Optional MXC URI for user-facing icon""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py deleted file mode 100644 index 92a5b606c8..0000000000 --- a/synapse/http/endpoint.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import re - -logger = logging.getLogger(__name__) - - -def parse_server_name(server_name): - """Split a server name into host/port parts. 
- - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - try: - if server_name[-1] == "]": - # ipv6 literal, hopefully - return server_name, None - - domain_port = server_name.rsplit(":", 1) - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None - return domain, port - except Exception: - raise ValueError("Invalid server name '%s'" % server_name) - - -VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") - - -def parse_and_validate_server_name(server_name): - """Split a server name into host/port parts and do some basic validation. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - host, port = parse_server_name(server_name) - - # these tests don't need to be bulletproof as we'll find out soon enough - # if somebody is giving us invalid data. What we *do* need is to be sure - # that nobody is sneaking IP literals in that look like hostnames, etc. - - # look for ipv6 literals - if host[0] == "[": - if host[-1] != "]": - raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) - return host, port - - # otherwise it should only be alphanumerics. - if not VALID_HOST_REGEX.match(host): - raise ValueError( - "Server name '%s' contains invalid characters" % (server_name,) - ) - - return host, port diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index f53c9cd679..5b38481012 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -17,6 +17,9 @@
+{% if p.idp_icon %} +<img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/> +{% endif %}
  • {% endfor %} diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e6725b03b0..f95627ee61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -32,7 +32,6 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -47,7 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder -from synapse.util.stringutils import random_string +from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: import synapse.server diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 284f2ce77c..a9fcb5f59c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -16,7 +16,6 @@ import collections import logging -import re from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple @@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +from synapse.util.stringutils import MXC_REGEX logger = logging.getLogger(__name__) @@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - sql = """ SELECT stream_ordering, json FROM events JOIN event_json USING (room_id, event_id) @@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore): for url in (content_url, thumbnail_url): if not url: continue - matches = mxc_re.match(url) + matches = MXC_REGEX.match(url) if matches: hostname = matches.group(1) media_id = matches.group(2) diff --git a/synapse/types.py b/synapse/types.py index 20a43d05bf..eafe729dfe 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -37,7 +37,7 @@ from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b103c8694c..f8038bf861 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -18,6 +18,7 @@ import random import re import string from collections.abc import Iterable +from typing import Optional, Tuple from synapse.api.errors import Codes, SynapseError @@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, +# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically +# says "there is no grammar for media ids" +# +# The server_name part of this is purposely lax: use 
parse_and_validate_mxc for +# additional validation. +# +MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") + # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. @@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret): ) +def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == "]": + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") + + +def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) + + return host, port + + +def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]: + """Parse the given string as an MXC URI + + Checks that the "server name" part is a valid server name + + Args: + mxc: the (alleged) MXC URI to be checked + Returns: + hostname, port, media id + Raises: + ValueError if the URI cannot be parsed + """ + m = MXC_REGEX.match(mxc) + if not m: + raise ValueError("mxc URI %r did not match expected format" % (mxc,)) + server_name = m.group(1) + media_id = m.group(2) + host, port = parse_and_validate_server_name(server_name) + return host, port, media_id + + def shortstr(iterable: Iterable, maxitems: int = 5) -> str: """If iterable has maxitems or fewer, return the stringification of a list containing those items. diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index b2e9533b07..d06ea518ce 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name +from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name from tests import unittest -- cgit 1.5.1 From 7447f197026db570c1c1af240642566b31f81e42 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 21 Jan 2021 12:25:02 +0000 Subject: Prefix idp_id with "oidc-" (#9189) ... 
to avoid clashes with other SSO mechanisms --- changelog.d/9189.misc | 1 + docs/sample_config.yaml | 13 +++++++++---- synapse/config/oidc_config.py | 28 ++++++++++++++++++++++++---- tests/rest/client/v1/test_login.py | 2 +- 4 files changed, 35 insertions(+), 9 deletions(-) create mode 100644 changelog.d/9189.misc (limited to 'tests') diff --git a/changelog.d/9189.misc b/changelog.d/9189.misc new file mode 100644 index 0000000000..9a5740aac2 --- /dev/null +++ b/changelog.d/9189.misc @@ -0,0 +1 @@ +Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b49a5da8cc..87bfe22237 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1728,7 +1728,9 @@ # # idp_icon: An optional icon for this identity provider, which is presented # by identity picker pages. If given, must be an MXC URI of the format -# mxc://<server-name>/<media-id> +# mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI +# is to upload an image to an (unencrypted) room and then copy the "url" +# from the source of the event.) # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -1814,13 +1816,16 @@ # # For backwards compatibility, it is also possible to configure a single OIDC # provider via an 'oidc_config' setting. This is now deprecated and admins are -# advised to migrate to the 'oidc_providers' format. +# advised to migrate to the 'oidc_providers' format. (When doing that migration, +# use 'oidc' for the idp_id to ensure that existing users continue to be +# recognised.) # oidc_providers: # Generic example # #- idp_id: my_idp # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" # discover: false # issuer: "https://accounts.example.com/" # client_id: "provided-by-your-issuer" @@ -1844,8 +1849,8 @@ # For use with Github # - #- idp_id: google - # idp_name: Google + #- idp_id: github + # idp_name: Github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 8cb0c42f36..d58a83be7f 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -69,7 +69,9 @@ class OIDCConfig(Config): # # idp_icon: An optional icon for this identity provider, which is presented # by identity picker pages. If given, must be an MXC URI of the format - # mxc://<server-name>/<media-id> + # mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI + # is to upload an image to an (unencrypted) room and then copy the "url" + # from the source of the event.) # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -155,13 +157,16 @@ class OIDCConfig(Config): # # For backwards compatibility, it is also possible to configure a single OIDC # provider via an 'oidc_config' setting. This is now deprecated and admins are - # advised to migrate to the 'oidc_providers' format. + # advised to migrate to the 'oidc_providers' format. (When doing that migration, + # use 'oidc' for the idp_id to ensure that existing users continue to be + # recognised.)
# oidc_providers: # Generic example # #- idp_id: my_idp # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" # discover: false # issuer: "https://accounts.example.com/" # client_id: "provided-by-your-issuer" @@ -185,8 +190,8 @@ class OIDCConfig(Config): # For use with Github # - #- idp_id: google - # idp_name: Google + #- idp_id: github + # idp_name: Github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -210,6 +215,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "object", "required": ["issuer", "client_id", "client_secret"], "properties": { + # TODO: fix the maxLength here depending on what MSC2528 decides + # remember that we prefix the ID given here with `oidc-` "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, "idp_icon": {"type": "string"}, @@ -335,6 +342,8 @@ def _parse_oidc_config_dict( # enforce those limits now. # TODO: factor out this stuff to a generic function idp_id = oidc_config.get("idp_id", "oidc") + + # TODO: update this validity check based on what MSC2858 decides. valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._") if any(c not in valid_idp_chars for c in idp_id): @@ -348,6 +357,17 @@ def _parse_oidc_config_dict( "idp_id must start with a-z", config_path + ("idp_id",), ) + # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid + # clashes with other mechs (such as SAML, CAS). + # + # We allow "oidc" as an exception so that people migrating from old-style + # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to + # a new-style "oidc_providers" entry without changing the idp_id for their provider + # (and thereby invalidating their user_external_ids data). + + if idp_id != "oidc": + idp_id = "oidc-" + idp_id + # MSC2858 also specifies that the idp_icon must be a valid MXC uri idp_icon = oidc_config.get("idp_icon") if idp_icon is not None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2d25490374..2672ce24c6 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -446,7 +446,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): p.feed(channel.result["body"].decode("utf-8")) p.close() - self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"]) + self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"]) self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) -- cgit 1.5.1 From c55e62548c0fddd49e7182133880d2ccb03dbb42 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 21 Jan 2021 15:18:46 +0100 Subject: Add tests for List Users Admin API (#9045) --- changelog.d/9045.misc | 1 + synapse/rest/admin/users.py | 21 +++- tests/rest/admin/test_user.py | 223 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 215 insertions(+), 30 deletions(-) create mode 100644 changelog.d/9045.misc (limited to 'tests') diff --git a/changelog.d/9045.misc b/changelog.d/9045.misc new file mode 100644 index 0000000000..7f1886a0de --- /dev/null +++ b/changelog.d/9045.misc @@ -0,0 +1 @@ +Add tests to `test_user.UsersListTestCase` for List Users Admin API. 
\ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f39e3d6d5c..86198bab30 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet): The parameter `deactivated` can be used to include deactivated users. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + user_id = parse_string(request, "user_id", default=None) name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) @@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet): start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} - if len(users) >= limit: + if (start + limit) < total: ret["next_token"] = str(start + len(users)) return 200, ret diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 04599c2fcf..e48f8c1d7b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync +from synapse.types import JsonDict from tests import unittest from tests.test_utils import make_awaitable @@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.user1 = self.register_user( - "user1", "pass1", admin=False, displayname="Name 1" - ) - self.user2 = self.register_user( - "user2", "pass2", admin=False, displayname="Name 2" - ) - def test_no_auth(self): """ Try to list users without authentication. @@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ + self._create_users(1) other_user_token = self.login("user1", "pass1") channel = self.make_request("GET", self.url, access_token=other_user_token) @@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ List all users, including deactivated users. 
""" + self._create_users(2) + channel = self.make_request( "GET", self.url + "?deactivated=true", @@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(3, channel.json_body["total"]) # Check that all fields are available - for u in channel.json_body["users"]: - self.assertIn("name", u) - self.assertIn("is_guest", u) - self.assertIn("admin", u) - self.assertIn("user_type", u) - self.assertIn("deactivated", u) - self.assertIn("displayname", u) - self.assertIn("avatar_url", u) + self._check_fields(channel.json_body["users"]) def test_search_term(self): """Test that searching for a users works correctly""" @@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): # Check that users were returned self.assertTrue("users" in channel.json_body) + self._check_fields(channel.json_body["users"]) users = channel.json_body["users"] # Check that the expected number of users were returned @@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase): u = users[0] self.assertEqual(expected_user_id, u["name"]) + self._create_users(2) + + user1 = "@user1:test" + user2 = "@user2:test" + # Perform search tests - _search_test(self.user1, "er1") - _search_test(self.user1, "me 1") + _search_test(user1, "er1") + _search_test(user1, "me 1") - _search_test(self.user2, "er2") - _search_test(self.user2, "me 2") + _search_test(user2, "er2") + _search_test(user2, "me 2") - _search_test(self.user1, "er1", "user_id") - _search_test(self.user2, "er2", "user_id") + _search_test(user1, "er1", "user_id") + _search_test(user2, "er2", "user_id") # Test case insensitive - _search_test(self.user1, "ER1") - _search_test(self.user1, "NAME 1") + _search_test(user1, "ER1") + _search_test(user1, "NAME 1") - _search_test(self.user2, "ER2") - _search_test(self.user2, "NAME 2") + _search_test(user2, "ER2") + _search_test(user2, "NAME 2") - _search_test(self.user1, "ER1", "user_id") - _search_test(self.user2, "ER2", "user_id") + _search_test(user1, "ER1", "user_id") + _search_test(user2, "ER2", "user_id") _search_test(None, "foo") _search_test(None, "bar") @@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. 
+ """ + + # negative limit + channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid guests + channel = self.make_request( + "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid deactivated + channel = self.make_request( + "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + def test_limit(self): + """ + Testing list of users with limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 5) + self.assertEqual(channel.json_body["next_token"], "5") + self._check_fields(channel.json_body["users"]) + + def test_from(self): + """ + Testing list of users with a defined starting point (from) + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["users"]) + + def test_limit_and_from(self): + """ + Testing list of users with a defined starting point and limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(channel.json_body["next_token"], "15") + self.assertEqual(len(channel.json_body["users"]), 10) + self._check_fields(channel.json_body["users"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_users = 20 + # Create one less user (since there's already an admin user). 
+ self._create_users(number_users - 1) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 19) + self.assertEqual(channel.json_body["next_token"], "19") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + Args: + content: List that is checked for content + """ + for u in content: + self.assertIn("name", u) + self.assertIn("is_guest", u) + self.assertIn("admin", u) + self.assertIn("user_type", u) + self.assertIn("deactivated", u) + self.assertIn("displayname", u) + self.assertIn("avatar_url", u) + + def _create_users(self, number_users: int): + """ + Create a number of users + Args: + number_users: Number of users to be created + """ + for i in range(1, number_users + 1): + self.register_user( + "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i, + ) + class DeactivateAccountTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From a7882f98874684969910d3a6ed7d85f99114cc45 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 21 Jan 2021 14:53:58 -0500 Subject: Return a 404 if no valid thumbnail is found. (#9163) If no thumbnail of the requested type exists, return a 404 instead of erroring. This doesn't quite match the spec (which does not define what happens if no thumbnail can be found), but is consistent with what Synapse already does. 
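In outline, the change below makes thumbnail selection return `None` when no candidate matches, and the resource turns that into a 404. A rough sketch of the new flow (hypothetical helper names, not the exact resource code):

    def serve_thumbnail(request, thumbnail_infos, select_thumbnail):
        # select_thumbnail() may now come back empty instead of raising.
        file_info = select_thumbnail(thumbnail_infos)
        if not file_info:
            request.setResponseCode(404)  # previously this path raised and produced a 500
            return
        # ... otherwise stream the chosen thumbnail as before ...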
--- changelog.d/9163.bugfix | 1 + synapse/rest/media/v1/_base.py | 3 + synapse/rest/media/v1/thumbnail_resource.py | 236 ++++++++++++++++++---------- tests/rest/media/v1/test_media_storage.py | 25 ++- 4 files changed, 183 insertions(+), 82 deletions(-) create mode 100644 changelog.d/9163.bugfix (limited to 'tests') diff --git a/changelog.d/9163.bugfix b/changelog.d/9163.bugfix new file mode 100644 index 0000000000..c51cf6ca80 --- /dev/null +++ b/changelog.d/9163.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled). diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 31a41e4a27..f71a03a12d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -300,6 +300,7 @@ class FileInfo: thumbnail_height (int) thumbnail_method (str) thumbnail_type (str): Content type of thumbnail, e.g. image/png + thumbnail_length (int): The size of the media file, in bytes. """ def __init__( @@ -312,6 +313,7 @@ class FileInfo: thumbnail_height=None, thumbnail_method=None, thumbnail_type=None, + thumbnail_length=None, ): self.server_name = server_name self.file_id = file_id @@ -321,6 +323,7 @@ class FileInfo: self.thumbnail_height = thumbnail_height self.thumbnail_method = thumbnail_method self.thumbnail_type = thumbnail_type + self.thumbnail_length = thumbnail_length def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d6880f2e6e..d653a58be9 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,7 +16,7 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from twisted.web.http import Request @@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource): return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) - - if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos - ) - - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=media_info["url_cache"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) - else: - logger.info("Couldn't find any generated thumbnails") - respond_404(request) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_id, + url_cache=media_info["url_cache"], + server_name=None, + ) async def _select_or_generate_local_thumbnail( self, @@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_info["filesystem_id"], + url_cache=None, + server_name=server_name, + ) + async def _select_and_respond_with_thumbnail( + self, + request: Request, + desired_width: int, + desired_height: 
int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str] = None, + server_name: Optional[str] = None, + ) -> None: + """ + Respond to a request with an appropriate thumbnail from the previously generated thumbnails. + + Args: + request: The incoming request. + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + """ if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos + file_info = self._select_thumbnail( + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + file_id, + url_cache, + server_name, ) - file_info = FileInfo( - server_name=server_name, - file_id=media_info["filesystem_id"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] + if not file_info: + logger.info("Couldn't find a thumbnail matching the desired inputs") + respond_404(request) + return responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, file_info.thumbnail_type, file_info.thumbnail_length + ) else: logger.info("Failed to find any generated thumbnails") respond_404(request) @@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: int, desired_method: str, desired_type: str, - thumbnail_infos, - ) -> dict: + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str], + server_name: Optional[str], + ) -> Optional[FileInfo]: + """ + Choose an appropriate thumbnail from the previously generated thumbnails. + + Args: + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + + Returns: + The thumbnail which best matches the desired parameters. + """ + desired_method = desired_method.lower() + + # The chosen thumbnail. + thumbnail_info = None + d_w = desired_width d_h = desired_height - if desired_method.lower() == "crop": + if desired_method == "crop": + # Thumbnails that match equal or larger sizes of desired width/height. crop_info_list = [] + # Other thumbnails. crop_info_list2 = [] for info in thumbnail_infos: + # Skip thumbnails generated with different methods. 
+ if info["thumbnail_method"] != "crop": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] - if t_method == "crop": - aspect_quality = abs(d_w * t_h - d_h * t_w) - min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 - size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] - if t_w >= d_w or t_h >= d_h: - crop_info_list.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + aspect_quality = abs(d_w * t_h - d_h * t_w) + min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 + size_quality = abs((d_w - t_w) * (d_h - t_h)) + type_quality = desired_type != info["thumbnail_type"] + length_quality = info["thumbnail_length"] + if t_w >= d_w or t_h >= d_h: + crop_info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) - else: - crop_info_list2.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + ) + else: + crop_info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) + ) if crop_info_list: - return min(crop_info_list)[-1] - else: - return min(crop_info_list2)[-1] - else: + thumbnail_info = min(crop_info_list)[-1] + elif crop_info_list2: + thumbnail_info = min(crop_info_list2)[-1] + elif desired_method == "scale": + # Thumbnails that match equal or larger sizes of desired width/height. info_list = [] + # Other thumbnails. info_list2 = [] + for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "scale": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] size_quality = abs((d_w - t_w) * (d_h - t_h)) type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] - if t_method == "scale" and (t_w >= d_w or t_h >= d_h): + if t_w >= d_w or t_h >= d_h: info_list.append((size_quality, type_quality, length_quality, info)) - elif t_method == "scale": + else: info_list2.append( (size_quality, type_quality, length_quality, info) ) if info_list: - return min(info_list)[-1] - else: - return min(info_list2)[-1] + thumbnail_info = min(info_list)[-1] + elif info_list2: + thumbnail_info = min(info_list2)[-1] + + if thumbnail_info: + return FileInfo( + file_id=file_id, + url_cache=url_cache, + server_name=server_name, + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + thumbnail_length=thumbnail_info["thumbnail_length"], + ) + + # No matching thumbnail was found. 
+ return None diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index ae2b32b131..a6c6985173 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): config = self.default_config() config["media_store_path"] = self.media_store_path - config["thumbnail_requirements"] = {} config["max_image_pixels"] = 2000000 provider_config = { @@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) def test_thumbnail_crop(self): + """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( "crop", self.test_image.expected_cropped, self.test_image.expected_found ) def test_thumbnail_scale(self): + """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( "scale", self.test_image.expected_scaled, self.test_image.expected_found ) + def test_invalid_type(self): + """An invalid thumbnail type is never available.""" + self._test_thumbnail("invalid", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} + ) + def test_no_thumbnail_crop(self): + """ + Override the config to generate only scaled thumbnails, but request a cropped one. + """ + self._test_thumbnail("crop", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} + ) + def test_no_thumbnail_scale(self): + """ + Override the config to generate only cropped thumbnails, but request a scaled one. + """ + self._test_thumbnail("scale", None, False) + def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method channel = make_request( -- cgit 1.5.1 From 056327457ff471495741a539e99c840ed54afccd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 Jan 2021 19:44:08 +0000 Subject: Fix chain cover update to handle events with duplicate auth events (#9210) --- changelog.d/9210.bugfix | 1 + synapse/util/iterutils.py | 2 +- tests/util/test_itertools.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9210.bugfix (limited to 'tests') diff --git a/changelog.d/9210.bugfix b/changelog.d/9210.bugfix new file mode 100644 index 0000000000..f9e0765570 --- /dev/null +++ b/changelog.d/9210.bugfix @@ -0,0 +1 @@ +Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 6ef2b008a4..8d2411513f 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -78,7 +78,7 @@ def sorted_topologically( if node not in degree_map: continue - for edge in edges: + for edge in set(edges): if edge in degree_map: degree_map[node] += 1 diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 522c8061f9..1ef0af8e8f 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -92,3 +92,15 @@ class SortTopologically(TestCase): # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should # always get the same one. 
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) + + def test_duplicates(self): + "Test that a graph with duplicate edges works" + graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) + + def test_multiple_paths(self): + "Test that a graph with multiple paths between two nodes works" + graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]] + + self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) -- cgit 1.5.1 From 6f7417c3db54c9545e93b0428303f29973468d39 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 25 Jan 2021 07:27:16 -0500 Subject: Handle missing content keys when calculating presentable names. (#9165) Treat the content as untrusted and do not assume it is of the proper form. --- changelog.d/9165.bugfix | 1 + synapse/push/presentable_names.py | 26 ++-- tests/push/test_presentable_names.py | 229 +++++++++++++++++++++++++++++++++ tests/push/test_push_rule_evaluator.py | 2 +- 4 files changed, 242 insertions(+), 16 deletions(-) create mode 100644 changelog.d/9165.bugfix create mode 100644 tests/push/test_presentable_names.py (limited to 'tests') diff --git a/changelog.d/9165.bugfix b/changelog.d/9165.bugfix new file mode 100644 index 0000000000..58db22f484 --- /dev/null +++ b/changelog.d/9165.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where invalid data could cause errors when calculating the presentable room name for push. diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 7e50341d74..04c2c1482c 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -17,7 +17,7 @@ import logging import re from typing import TYPE_CHECKING, Dict, Iterable, Optional -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.types import StateMap @@ -63,7 +63,7 @@ async def calculate_room_name( m_room_name = await store.get_event( room_state_ids[(EventTypes.Name, "")], allow_none=True ) - if m_room_name and m_room_name.content and m_room_name.content["name"]: + if m_room_name and m_room_name.content and m_room_name.content.get("name"): return m_room_name.content["name"] # does it have a canonical alias? 
@@ -74,15 +74,11 @@ async def calculate_room_name( if ( canon_alias and canon_alias.content - and canon_alias.content["alias"] + and canon_alias.content.get("alias") and _looks_like_an_alias(canon_alias.content["alias"]) ): return canon_alias.content["alias"] - # at this point we're going to need to search the state by all state keys - # for an event type, so rearrange the data structure - room_state_bytype_ids = _state_as_two_level_dict(room_state_ids) - if not fallback_to_members: return None @@ -94,7 +90,7 @@ async def calculate_room_name( if ( my_member_event is not None - and my_member_event.content["membership"] == "invite" + and my_member_event.content.get("membership") == Membership.INVITE ): if (EventTypes.Member, my_member_event.sender) in room_state_ids: inviter_member_event = await store.get_event( @@ -111,6 +107,10 @@ async def calculate_room_name( else: return "Room Invite" + # at this point we're going to need to search the state by all state keys + # for an event type, so rearrange the data structure + room_state_bytype_ids = _state_as_two_level_dict(room_state_ids) + # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. if EventTypes.Member in room_state_bytype_ids: @@ -120,8 +120,8 @@ async def calculate_room_name( all_members = [ ev for ev in member_events.values() - if ev.content["membership"] == "join" - or ev.content["membership"] == "invite" + if ev.content.get("membership") == Membership.JOIN + or ev.content.get("membership") == Membership.INVITE ] # Sort the member events oldest-first so that we name people in the # order they joined (it should at least be deterministic rather than @@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str: def name_from_member_event(member_event: EventBase) -> str: - if ( - member_event.content - and "displayname" in member_event.content - and member_event.content["displayname"] - ): + if member_event.content and member_event.content.get("displayname"): return member_event.content["displayname"] return member_event.state_key diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py new file mode 100644 index 0000000000..aff563919d --- /dev/null +++ b/tests/push/test_presentable_names.py @@ -0,0 +1,229 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable, Optional, Tuple + +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEvent +from synapse.push.presentable_names import calculate_room_name +from synapse.types import StateKey, StateMap + +from tests import unittest + + +class MockDataStore: + """ + A fake data store which stores a mapping of state key to event content. + (I.e. the state key is used as the event ID.) 
+ """ + + def __init__(self, events: Iterable[Tuple[StateKey, dict]]): + """ + Args: + events: A state map to event contents. + """ + self._events = {} + + for i, (event_id, content) in enumerate(events): + self._events[event_id] = FrozenEvent( + { + "event_id": "$event_id", + "type": event_id[0], + "sender": "@user:test", + "state_key": event_id[1], + "room_id": "#room:test", + "content": content, + "origin_server_ts": i, + }, + RoomVersions.V1, + ) + + async def get_event( + self, event_id: StateKey, allow_none: bool = False + ) -> Optional[FrozenEvent]: + assert allow_none, "Mock not configured for allow_none = False" + + return self._events.get(event_id) + + async def get_events(self, event_ids: Iterable[StateKey]): + # This is cheating since it just returns all events. + return self._events + + +class PresentableNamesTestCase(unittest.HomeserverTestCase): + USER_ID = "@test:test" + OTHER_USER_ID = "@user:test" + + def _calculate_room_name( + self, + events: StateMap[dict], + user_id: str = "", + fallback_to_members: bool = True, + fallback_to_single_member: bool = True, + ): + # This isn't 100% accurate, but works with MockDataStore. + room_state_ids = {k[0]: k[0] for k in events} + + return self.get_success( + calculate_room_name( + MockDataStore(events), + room_state_ids, + user_id or self.USER_ID, + fallback_to_members, + fallback_to_single_member, + ) + ) + + def test_name(self): + """A room name event should be used.""" + events = [ + ((EventTypes.Name, ""), {"name": "test-name"}), + ] + self.assertEqual("test-name", self._calculate_room_name(events)) + + # Check if the event content has garbage. + events = [((EventTypes.Name, ""), {"foo": 1})] + self.assertEqual("Empty Room", self._calculate_room_name(events)) + + events = [((EventTypes.Name, ""), {"name": 1})] + self.assertEqual(1, self._calculate_room_name(events)) + + def test_canonical_alias(self): + """A canonical alias should be used.""" + events = [ + ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}), + ] + self.assertEqual("#test-name:test", self._calculate_room_name(events)) + + # Check if the event content has garbage. + events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})] + self.assertEqual("Empty Room", self._calculate_room_name(events)) + + events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})] + self.assertEqual("Empty Room", self._calculate_room_name(events)) + + def test_invite(self): + """An invite has special behaviour.""" + events = [ + ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}), + ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}), + ] + self.assertEqual("Invite from Other User", self._calculate_room_name(events)) + self.assertIsNone( + self._calculate_room_name(events, fallback_to_single_member=False) + ) + # Ensure this logic is skipped if we don't fall back to members. + self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False)) + + # Check if the event content has garbage. + events = [ + ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}), + ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}), + ] + self.assertEqual("Invite from @user:test", self._calculate_room_name(events)) + + # No member event for sender. 
+ events = [ + ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}), + ] + self.assertEqual("Room Invite", self._calculate_room_name(events)) + + def test_no_members(self): + """Behaviour of an empty room.""" + events = [] + self.assertEqual("Empty Room", self._calculate_room_name(events)) + + # Note that events with invalid (or missing) membership are ignored. + events = [ + ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}), + ((EventTypes.Member, "@foo:test"), {"membership": "foo"}), + ] + self.assertEqual("Empty Room", self._calculate_room_name(events)) + + def test_no_other_members(self): + """Behaviour of a room with no other members in it.""" + events = [ + ( + (EventTypes.Member, self.USER_ID), + {"membership": Membership.JOIN, "displayname": "Me"}, + ), + ] + self.assertEqual("Me", self._calculate_room_name(events)) + + # Check if the event content has no displayname. + events = [ + ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}), + ] + self.assertEqual("@test:test", self._calculate_room_name(events)) + + # 3pid invite, use the other user (who is set as the sender). + events = [ + ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}), + ] + self.assertEqual( + "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID) + ) + + events = [ + ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}), + ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}), + ] + self.assertEqual( + "Inviting email address", + self._calculate_room_name(events, user_id=self.OTHER_USER_ID), + ) + + def test_one_other_member(self): + """Behaviour of a room with a single other member.""" + events = [ + ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}), + ( + (EventTypes.Member, self.OTHER_USER_ID), + {"membership": Membership.JOIN, "displayname": "Other User"}, + ), + ] + self.assertEqual("Other User", self._calculate_room_name(events)) + self.assertIsNone( + self._calculate_room_name(events, fallback_to_single_member=False) + ) + + # Check if the event content has no displayname and is an invite. + events = [ + ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}), + ( + (EventTypes.Member, self.OTHER_USER_ID), + {"membership": Membership.INVITE}, + ), + ] + self.assertEqual("@user:test", self._calculate_room_name(events)) + + def test_other_members(self): + """Behaviour of a room with multiple other members.""" + # Two other members. + events = [ + ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}), + ( + (EventTypes.Member, self.OTHER_USER_ID), + {"membership": Membership.JOIN, "displayname": "Other User"}, + ), + ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}), + ] + self.assertEqual("Other User and @foo:test", self._calculate_room_name(events)) + + # Three or more other members. 
+ events.append( + ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE}) + ) + self.assertEqual("Other User and 2 others", self._calculate_room_name(events)) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 1f4b5ca2ac..4a841f5bb8 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "type": "m.room.history_visibility", "sender": "@user:test", "state_key": "", - "room_id": "@room:test", + "room_id": "#room:test", "content": content, }, RoomVersions.V1, -- cgit 1.5.1 From 4a55d267eef1388690e6781b580910e341358f95 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 25 Jan 2021 14:49:39 -0500 Subject: Add an admin API for shadow-banning users. (#9209) This expands the current shadow-banning feature to be usable via the admin API and adds documentation for it. A shadow-banned user receives successful responses to their client-server API requests, but the events are not propagated into rooms. Shadow-banning a user should be used as a tool of last resort and may lead to confusing or broken behaviour for the client. --- changelog.d/9209.feature | 1 + docs/admin_api/user_admin_api.rst | 30 ++++++++++++ stubs/txredisapi.pyi | 1 - synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 36 +++++++++++++++ synapse/storage/databases/main/registration.py | 29 ++++++++++++ tests/rest/admin/test_user.py | 64 ++++++++++++++++++++++++++ tests/rest/client/test_shadow_banned.py | 8 +--- 8 files changed, 164 insertions(+), 7 deletions(-) create mode 100644 changelog.d/9209.feature (limited to 'tests') diff --git a/changelog.d/9209.feature b/changelog.d/9209.feature new file mode 100644 index 0000000000..ec926e8eb4 --- /dev/null +++ b/changelog.d/9209.feature @@ -0,0 +1 @@ +Add an admin API endpoint for shadow-banning users. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index b3d413cf57..1eb674939e 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -760,3 +760,33 @@ The following fields are returned in the JSON response body: - ``total`` - integer - Number of pushers. See also `Client-Server API Spec `_ + +Shadow-banning users +==================== + +Shadow-banning is a useful tool for moderating malicious or egregiously abusive users. +A shadow-banned user receives successful responses to their client-server API requests, +but the events are not propagated into rooms. This can be an effective tool as it +(hopefully) takes the user longer to realise they are being moderated before they +pivot to another account. + +Shadow-banning a user should be used as a tool of last resort and may lead to confusing +or broken behaviour for the client. A shadow-banned user will not receive any +notification and it is generally more appropriate to ban or kick abusive users. +A shadow-banned user will be unable to contact anyone on the server. + +The API is:: + + POST /_synapse/admin/v1/users/<user_id>/shadow_ban + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst `_. + +An empty JSON dict is returned. + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must + be local. 
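For illustration, a minimal sketch of driving the new endpoint from Python. The homeserver URL, the admin token, and the use of the `requests` library are assumptions of this example, not part of the change:

import urllib.parse

import requests  # assumed HTTP client; any client that can POST works

HOMESERVER = "https://matrix.example.com"  # placeholder homeserver base URL
ADMIN_TOKEN = "<admin_access_token>"  # placeholder admin access token


def shadow_ban(user_id: str) -> None:
    """POST to the shadow-ban admin API for a local user."""
    url = "%s/_synapse/admin/v1/users/%s/shadow_ban" % (
        HOMESERVER,
        urllib.parse.quote(user_id),
    )
    resp = requests.post(
        url, json={}, headers={"Authorization": "Bearer " + ADMIN_TOKEN}
    )
    # Success returns an empty JSON dict; non-local users yield a 400.
    resp.raise_for_status()


shadow_ban("@spammer:matrix.example.com")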
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index bfac6840e6..726454ba31 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -15,7 +15,6 @@ """Contains *incomplete* type hints for txredisapi. """ - from typing import List, Optional, Type, Union class RedisProtocol: diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6f7dc06503..f04740cd38 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -51,6 +51,7 @@ from synapse.rest.admin.users import ( PushersRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, + ShadowBanRestServlet, UserAdminServlet, UserMediaRestServlet, UserMembershipRestServlet, @@ -230,6 +231,7 @@ def register_servlets(hs, http_server): EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) + ShadowBanRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 86198bab30..68c3c64a0d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -890,3 +890,39 @@ class UserTokenRestServlet(RestServlet): ) return 200, {"access_token": token} + + +class ShadowBanRestServlet(RestServlet): + """An admin API for shadow-banning a user. + + A shadow-banned user receives successful responses to their client-server + API requests, but the events are not propagated into rooms. + + Shadow-banning a user should be used as a tool of last resort and may lead + to confusing or broken behaviour for the client. + + Example: + + POST /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be shadow-banned") + + await self.store.set_shadow_banned(UserID.from_string(user_id), True) + + return 200, {} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 585b4049d6..0618b4387a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None: + """Sets whether a user is shadow-banned. + + Args: + user: user ID of the user to update + shadow_banned: true iff the user is to be shadow-banned, false otherwise. + """ + + def set_shadow_banned_txn(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user.to_string()}, + updatevalues={"shadow_banned": shadow_banned}, + ) + # In order for this to apply immediately, clear the cache for this user. 
+ tokens = self.db_pool.simple_select_onecol_txn( + txn, + table="access_tokens", + keyvalues={"user_id": user.to_string()}, + retcol="token", + ) + for token in tokens: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (token,) + ) + + await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e48f8c1d7b..ee05ee60bc 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2380,3 +2380,67 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) + + +class ShadowBanRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + + self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to shadow-ban a user without authentication. + """ + channel = self.make_request("POST", self.url) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + channel = self.make_request("POST", self.url, access_token=other_user_token) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that shadow-banning a user that is not local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/shadow_ban" + + channel = self.make_request("POST", url, access_token=self.admin_user_tok) + self.assertEqual(400, channel.code, msg=channel.json_body) + + def test_success(self): + """ + Shadow-banning should succeed for an admin. + """ + # The user starts off as not shadow-banned. + other_user_token = self.login("user", "pass") + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertFalse(result.shadow_banned) + + channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + # Ensure the user is shadow-banned (and the cache was cleared). 
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertTrue(result.shadow_banned) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index e689c3fbea..0ebdf1415b 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -18,6 +18,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.types import UserID from tests import unittest @@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase): self.store = self.hs.get_datastore() self.get_success( - self.store.db_pool.simple_update( - table="users", - keyvalues={"name": self.banned_user_id}, - updatevalues={"shadow_banned": True}, - desc="shadow_ban", - ) + self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) ) self.other_user_id = self.register_user("otheruser", "pass") -- cgit 1.5.1 From 4937fe3d6be94222b02760866496781f8cc88751 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jan 2021 07:32:17 -0500 Subject: Try to recover from unknown encodings when previewing media. (#9164) Treat unknown encodings (according to lxml) as UTF-8 when generating a preview for HTML documents. This isn't fully accurate, but will hopefully give a reasonable title and summary. --- changelog.d/9164.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 44 +++++++++++++++++++++------ tests/test_preview.py | 29 ++++++++++++++++++ 3 files changed, 64 insertions(+), 10 deletions(-) create mode 100644 changelog.d/9164.bugfix (limited to 'tests') diff --git a/changelog.d/9164.bugfix b/changelog.d/9164.bugfix new file mode 100644 index 0000000000..1c54a256c1 --- /dev/null +++ b/changelog.d/9164.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index a632099167..bf3be653aa 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -386,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Check whether the URL should be downloaded as oEmbed content instead. - Params: + Args: url: The URL to check. Returns: @@ -403,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Request content from an oEmbed endpoint. - Params: + Args: endpoint: The oEmbed API endpoint. url: The URL to pass to the API. @@ -692,27 +692,51 @@ class PreviewUrlResource(DirectServeJsonResource): def decode_and_calc_og( body: bytes, media_uri: str, request_encoding: Optional[str] = None ) -> Dict[str, Optional[str]]: + """ + Calculate metadata for an HTML document. + + This uses lxml to parse the HTML document into the OG response. If errors + occur during processing of the document, an empty response is returned. + + Args: + body: The HTML document, as bytes. + media_uri: The URI used to download the body. + request_encoding: The character encoding of the body, as a string. + + Returns: + The OG response as a dictionary. + """ # If there's no body, nothing useful is going to be found. if not body: return {} from lxml import etree + # Create an HTML parser. If this fails, log and return no metadata. 
try: parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body, parser) - og = _calc_og(tree, media_uri) + except LookupError: + # blindly consider the encoding as utf-8. + parser = etree.HTMLParser(recover=True, encoding="utf-8") + except Exception as e: + logger.warning("Unable to create HTML parser: %s" % (e,)) + return {} + + def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]: + # Attempt to parse the body. If this fails, log and return no metadata. + tree = etree.fromstring(body_attempt, parser) + return _calc_og(tree, media_uri) + + # Attempt to parse the body. If this fails, log and return no metadata. + try: + return _attempt_calc_og(body) except UnicodeDecodeError: # blindly try decoding the body as utf-8, which seems to fix # the charset mismatches on https://google.com - parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body.decode("utf-8", "ignore"), parser) - og = _calc_og(tree, media_uri) - - return og + return _attempt_calc_og(body.decode("utf-8", "ignore")) -def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: +def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them diff --git a/tests/test_preview.py b/tests/test_preview.py index c19facc1cb..0c6cbbd921 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -261,3 +261,32 @@ class PreviewUrlTestCase(unittest.TestCase): html = "" og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEqual(og, {}) + + def test_invalid_encoding(self): + """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" + html = """ + + Foo + + Some text. + + + """ + og = decode_and_calc_og( + html, "http://example.com/test.html", "invalid-encoding" + ) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) + + def test_invalid_encoding2(self): + """A body which doesn't match the sent character encoding.""" + # Note that this contains an invalid UTF-8 sequence in the title. + html = b""" + + \xff\xff Foo + + Some text. + + + """ + og = decode_and_calc_og(html, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) -- cgit 1.5.1 From dd8da8c5f6ac525a7456437913a03f68d4504605 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Jan 2021 13:57:31 +0000 Subject: Precompute joined hosts and store in Redis (#9198) --- changelog.d/9198.misc | 1 + stubs/txredisapi.pyi | 12 +++- synapse/config/_base.pyi | 2 + synapse/federation/sender/__init__.py | 50 +++++++++----- synapse/handlers/federation.py | 5 ++ synapse/handlers/message.py | 42 ++++++++++++ synapse/replication/tcp/external_cache.py | 105 ++++++++++++++++++++++++++++++ synapse/replication/tcp/handler.py | 15 +---- synapse/server.py | 30 +++++++++ synapse/state/__init__.py | 11 +++- tests/replication/_base.py | 41 +++++++----- 11 files changed, 265 insertions(+), 49 deletions(-) create mode 100644 changelog.d/9198.misc create mode 100644 synapse/replication/tcp/external_cache.py (limited to 'tests') diff --git a/changelog.d/9198.misc b/changelog.d/9198.misc new file mode 100644 index 0000000000..a6cb77fbb2 --- /dev/null +++ b/changelog.d/9198.misc @@ -0,0 +1 @@ +Precompute joined hosts and store in Redis. 
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index bdc892ec82..618548a305 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -15,11 +15,21 @@ """Contains *incomplete* type hints for txredisapi. """ -from typing import List, Optional, Type, Union +from typing import Any, List, Optional, Type, Union class RedisProtocol: def publish(self, channel: str, message: bytes): ... async def ping(self) -> None: ... + async def set( + self, + key: str, + value: Any, + expire: Optional[int] = None, + pexpire: Optional[int] = None, + only_if_not_exists: bool = False, + only_if_exists: bool = False, + ) -> None: ... + async def get(self, key: str) -> Any: ... class SubscriberProtocol(RedisProtocol): def __init__(self, *args, **kwargs): ... diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 29aa064e57..8ba669059a 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -18,6 +18,7 @@ from synapse.config import ( password_auth_providers, push, ratelimiting, + redis, registration, repository, room_directory, @@ -79,6 +80,7 @@ class RootConfig: roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig tracer: tracer.TracerConfig + redis: redis.RedisConfig config_classes: List = ... def __init__(self) -> None: ... diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 604cfd1935..643b26ae6d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -142,6 +142,8 @@ class FederationSender: self._wake_destinations_needing_catchup, ) + self._external_cache = hs.get_external_cache() + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -197,22 +199,40 @@ class FederationSender: if not event.internal_metadata.should_proactively_send(): return - try: - # Get the state from before the event. - # We need to make sure that this is the state from before - # the event and not from after it. - # Otherwise if the last member on a server in a room is - # banned then it won't receive the event because it won't - # be in the room after the ban. - destinations = await self.state.get_hosts_in_room_at_events( - event.room_id, event_ids=event.prev_event_ids() - ) - except Exception: - logger.exception( - "Failed to calculate hosts in room for event: %s", - event.event_id, + destinations = None # type: Optional[Set[str]] + if not event.prev_event_ids(): + # If there are no prev event IDs then the state is empty + # and so no remote servers in the room + destinations = set() + else: + # We check the external cache for the destinations, which is + # stored per state group. + + sg = await self._external_cache.get( + "event_to_prev_state_group", event.event_id ) - return + if sg: + destinations = await self._external_cache.get( + "get_joined_hosts", str(sg) + ) + + if destinations is None: + try: + # Get the state from before the event. + # We need to make sure that this is the state from before + # the event and not from after it. + # Otherwise if the last member on a server in a room is + # banned then it won't receive the event because it won't + # be in the room after the ban. 
+ destinations = await self.state.get_hosts_in_room_at_events( + event.room_id, event_ids=event.prev_event_ids() + ) + except Exception: + logger.exception( + "Failed to calculate hosts in room for event: %s", + event.event_id, + ) + return destinations = { d diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fd8de8696d..b6dc7f99b6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.GuestAccess and not context.rejected: await self.maybe_kick_guest_users(event) + # If we are going to send this event over federation we precalculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + return context async def _check_for_soft_fail( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9dfeab09cd..e2a7d567fa 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -432,6 +432,8 @@ class EventCreationHandler: self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._external_cache = hs.get_external_cache() + async def create_event( self, requester: Requester, @@ -939,6 +941,8 @@ class EventCreationHandler: await self.action_generator.handle_push_actions_for_event(event, context) + await self.cache_joined_hosts_for_event(event) + try: # If we're a worker we need to hit out to the master. writer_instance = self._events_shard_config.get_instance(event.room_id) @@ -978,6 +982,44 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise + async def cache_joined_hosts_for_event(self, event: EventBase) -> None: + """Precalculate the joined hosts at the event, when using Redis, so that + external federation senders don't have to recalculate it themselves. + """ + + if not self._external_cache.is_enabled(): + return + + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We always set the state group -> joined hosts cache, even if + # we already set it, so that the expiry time is reset. + + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() + ) + + if state_entry.state_group: + joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) + async def _validate_canonical_alias( + self, directory_handler, room_alias_str: str, expected_room_id: str + ) -> None: diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py new file mode 100644 index 0000000000..34fa3ff5b3 --- /dev/null +++ b/synapse/replication/tcp/external_cache.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Any, Optional + +from prometheus_client import Counter + +from synapse.logging.context import make_deferred_yieldable +from synapse.util import json_decoder, json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + +set_counter = Counter( + "synapse_external_cache_set", + "Number of times we set a cache", + labelnames=["cache_name"], +) + +get_counter = Counter( + "synapse_external_cache_get", + "Number of times we get a cache", + labelnames=["cache_name", "hit"], +) + + +logger = logging.getLogger(__name__) + + +class ExternalCache: + """A cache backed by an external Redis. Does nothing if no Redis is + configured. + """ + + def __init__(self, hs: "HomeServer"): + self._redis_connection = hs.get_outbound_redis_connection() + + def _get_redis_key(self, cache_name: str, key: str) -> str: + return "cache_v1:%s:%s" % (cache_name, key) + + def is_enabled(self) -> bool: + """Whether the external cache is used or not. + + It's safe to use the cache when this returns false, the methods will + just no-op, but the function is useful to avoid doing unnecessary work. + """ + return self._redis_connection is not None + + async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None: + """Add the key/value to the named cache, with the expiry time given. + """ + + if self._redis_connection is None: + return + + set_counter.labels(cache_name).inc() + + # txredisapi requires the value to be string, bytes or numbers, so we + # encode stuff in JSON. + encoded_value = json_encoder.encode(value) + + logger.debug("Caching %s %s: %r", cache_name, key, encoded_value) + + return await make_deferred_yieldable( + self._redis_connection.set( + self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms, + ) + ) + + async def get(self, cache_name: str, key: str) -> Optional[Any]: + """Look up a key/value in the named cache. + """ + + if self._redis_connection is None: + return None + + result = await make_deferred_yieldable( + self._redis_connection.get(self._get_redis_key(cache_name, key)) + ) + + logger.debug("Got cache result %s %s: %r", cache_name, key, result) + + get_counter.labels(cache_name, result is not None).inc() + + if not result: + return None + + # For some reason the integers get magically converted back to integers + if isinstance(result, int): + return result + + return json_decoder.decode(result) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 58d46a5951..8ea8dcd587 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -286,13 +286,6 @@ class ReplicationCommandHandler: if hs.config.redis.redis_enabled: from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, - lazyConnection, - ) - - logger.info( - "Connecting to redis (host=%r port=%r)", - hs.config.redis_host, - hs.config.redis_port, ) # First let's ensure that we have a ReplicationStreamer started. @@ -303,13 +296,7 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). 
# First create the connection for sending commands. - outbound_redis_connection = lazyConnection( - hs=hs, - host=hs.config.redis_host, - port=hs.config.redis_port, - password=hs.config.redis.redis_password, - reconnect=True, - ) + outbound_redis_connection = hs.get_outbound_redis_connection() # Now create the factory/connection for the subscription stream. self._factory = RedisDirectTcpReplicationClientFactory( diff --git a/synapse/server.py b/synapse/server.py index 9cdda83aa1..9bdd3177d7 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -103,6 +103,7 @@ from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.external_cache import ExternalCache from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.streams import STREAMS_MAP, Stream @@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) if TYPE_CHECKING: + from txredisapi import RedisProtocol + from synapse.handlers.oidc_handler import OidcHandler from synapse.handlers.saml_handler import SamlHandler @@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta): def get_account_data_handler(self) -> AccountDataHandler: return AccountDataHandler(self) + @cache_in_self + def get_external_cache(self) -> ExternalCache: + return ExternalCache(self) + + @cache_in_self + def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: + if not self.config.redis.redis_enabled: + return None + + # We only want to import redis module if we're using it, as we have + # `txredisapi` as an optional dependency. + from synapse.replication.tcp.redis import lazyConnection + + logger.info( + "Connecting to redis (host=%r port=%r) for external cache", + self.config.redis_host, + self.config.redis_port, + ) + + return lazyConnection( + hs=self, + host=self.config.redis_host, + port=self.config.redis_port, + password=self.config.redis.redis_password, + reconnect=True, + ) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 84f59c7d85..3bd9ff8ca0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -310,6 +310,7 @@ class StateHandler: state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None + entry = None else: # otherwise, we'll need to resolve the state across the prev_events. @@ -340,9 +341,13 @@ class StateHandler: current_state_ids=state_ids_before_event, ) - # XXX: can we update the state cache entry for the new state group? or - # could we set a flag on resolve_state_groups_for_events to tell it to - # always make a state group? + # Assign the new state group to the cached state entry. + # + # Note that this can race in that we could generate multiple state + # groups for the same state entry, but that is just inefficient + # rather than dangerous. 
+ if entry and entry.state_group is None: + entry.state_group = state_group_before_event # # now if it's not a state event, we're done diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 3379189785..d5dce1f83f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Fake in memory Redis server that servers can connect to. self._redis_server = FakeRedisPubSubServer() + # We may have an attempt to connect to redis for the external cache already. + self.connect_any_redis_attempts() + store = self.hs.get_datastore() self.database_pool = store.db_pool @@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): fake one. """ clients = self.reactor.tcpClients - self.assertEqual(len(clients), 1) - (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") - self.assertEqual(port, 6379) + while clients: + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) - client_protocol = client_factory.buildProtocol(None) - server_protocol = self._redis_server.buildProtocol(None) + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) - client_to_server_transport = FakeTransport( - server_protocol, self.reactor, client_protocol - ) - client_protocol.makeConnection(client_to_server_transport) - - server_to_client_transport = FakeTransport( - client_protocol, self.reactor, server_protocol - ) - server_protocol.makeConnection(server_to_client_transport) + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) - return client_to_server_transport, server_to_client_transport + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) class TestReplicationDataHandler(GenericWorkerReplicationHandler): @@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol): (channel,) = args self._server.add_subscriber(self) self.send(["subscribe", channel, 1]) + + # Since we use SET/GET to cache things we can safely no-op them. + elif command == b"SET": + self.send("OK") + elif command == b"GET": + self.send(None) else: raise Exception("Unknown command") @@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol): # We assume bytes are just unicode strings. obj = obj.decode("utf-8") + if obj is None: + return "$-1\r\n" if isinstance(obj, str): return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) if isinstance(obj, int): -- cgit 1.5.1 From a737cc27134c50059440ca33510b0baea53b4225 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 27 Jan 2021 12:41:24 +0000 Subject: Implement MSC2858 support (#9183) Fixes #8928. 
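For context, a sketch of what the change below advertises in the `GET /_matrix/client/r0/login` response when `msc2858_enabled` is set; the provider ids and names here are made up for the example, and the exact per-provider fields come from `_get_auth_flow_dict_for_idp`:

login_flows_response = {
    "flows": [
        {
            "type": "m.login.sso",
            "org.matrix.msc2858.identity_providers": [
                {"id": "oidc", "name": "Example OIDC provider"},
                {"id": "saml", "name": "Example SAML provider"},
            ],
        },
        {"type": "m.login.token"},
    ]
}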
--- changelog.d/9183.feature | 1 + synapse/config/_base.pyi | 2 + synapse/config/experimental.py | 29 ++++++++++++ synapse/config/homeserver.py | 2 + synapse/handlers/sso.py | 23 +++++++--- synapse/http/server.py | 44 ++++++++++++++---- synapse/rest/client/v1/login.py | 55 ++++++++++++++++++++--- tests/rest/client/v1/test_login.py | 92 ++++++++++++++++++++++++++++++++++++++ tests/utils.py | 3 +- 9 files changed, 230 insertions(+), 21 deletions(-) create mode 100644 changelog.d/9183.feature create mode 100644 synapse/config/experimental.py (limited to 'tests') diff --git a/changelog.d/9183.feature b/changelog.d/9183.feature new file mode 100644 index 0000000000..2d5c735042 --- /dev/null +++ b/changelog.d/9183.feature @@ -0,0 +1 @@ +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 29aa064e57..3ccea4b02d 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -9,6 +9,7 @@ from synapse.config import ( consent_config, database, emailconfig, + experimental, groups, jwt_config, key, @@ -48,6 +49,7 @@ def path_exists(file_path: str): ... class RootConfig: server: server.ServerConfig + experimental: experimental.ExperimentalConfig tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py new file mode 100644 index 0000000000..b1c1c51e4d --- /dev/null +++ b/synapse/config/experimental.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from synapse.config._base import Config +from synapse.types import JsonDict + + +class ExperimentalConfig(Config): + """Config section for enabling experimental features""" + + section = "experimental" + + def read_config(self, config: JsonDict, **kwargs): + experimental = config.get("experimental_features") or {} + + # MSC2858 (multiple SSO identity providers) + self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 4bd2b3587b..64a2429f77 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -24,6 +24,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .experimental import ExperimentalConfig from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig @@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, + ExperimentalConfig, TlsConfig, FederationConfig, CacheConfig, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d493327a10..afc1341d09 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol from twisted.web.http import Request from synapse.api.constants import LoginType -from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html @@ -235,7 +235,10 @@ class SsoHandler: respond_with_html(request, code, html) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes, + self, + request: SynapseRequest, + client_redirect_url: bytes, + idp_id: Optional[str], ) -> str: """Handle a request to /login/sso/redirect @@ -243,6 +246,7 @@ class SsoHandler: request: incoming HTTP request client_redirect_url: the URL that we should redirect the client to after login. + idp_id: optional identity provider chosen by the client Returns: the URI to redirect to @@ -252,10 +256,19 @@ class SsoHandler: 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED ) + # if the client chose an IdP, use that + idp = None # type: Optional[SsoIdentityProvider] + if idp_id: + idp = self._identity_providers.get(idp_id) + if not idp: + raise NotFoundError("Unknown identity provider") + # if we only have one auth provider, redirect to it directly - if len(self._identity_providers) == 1: - ap = next(iter(self._identity_providers.values())) - return await ap.handle_redirect_request(request, client_redirect_url) + elif len(self._identity_providers) == 1: + idp = next(iter(self._identity_providers.values())) + + if idp: + return await idp.handle_redirect_request(request, client_redirect_url) # otherwise, redirect to the IDP picker return "/_synapse/client/pick_idp?" 
+ urlencode( diff --git a/synapse/http/server.py b/synapse/http/server.py index e464bfe6c7..d69d579b3a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,10 +22,22 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Any, Callable, Dict, Iterator, List, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Pattern, + Tuple, + Union, +) import jinja2 from canonicaljson import iterencode_canonical_json +from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces @@ -168,11 +180,25 @@ def wrap_async_request_handler(h): return preserve_fn(wrapped_async_request_handler) -class HttpServer: +# Type of a callback method for processing requests +# it is actually called with a SynapseRequest and a kwargs dict for the params, +# but I can't figure out how to represent that. +ServletCallback = Callable[ + ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]] +] + + +class HttpServer(Protocol): """ Interface for registering callbacks on a HTTP server """ - def register_paths(self, method, path_patterns, callback): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. @@ -180,12 +206,14 @@ class HttpServer: an unpacked tuple. Args: - method (str): The method to listen to. - path_patterns (list): The regex used to match requests. - callback (function): The function to fire if we receive a matched + method: The HTTP method to listen to. + path_patterns: The regex used to match requests. + callback: The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. - This should return a tuple of (code, response). + This should return either tuple of (code, response), or None. + servlet_classname (str): The name of the handler to be used in prometheus + and opentracing logs. """ pass @@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource): def _get_handler_for_request( self, request: SynapseRequest - ) -> Tuple[Callable, str, Dict[str, str]]: + ) -> Tuple[ServletCallback, str, Dict[str, str]]: """Finds a callback method to handle the given request. 
Returns: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index be938df962..0a561eea60 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.http.server import finish_request +from synapse.handlers.sso import SsoIdentityProvider +from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self._msc2858_enabled = hs.config.experimental.msc2858_enabled self.auth = hs.get_auth() self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + self._sso_handler = hs.get_sso_handler() + self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - # While its valid for us to advertise this login type generally, + sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + + if self._msc2858_enabled: + sso_flow["org.matrix.msc2858.identity_providers"] = [ + _get_auth_flow_dict_for_idp(idp) + for idp in self._sso_handler.get_identity_providers().values() + ] + + flows.append(sso_flow) + + # While it's valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the # SSO login flow. 
# Generally we don't want to advertise login flows that clients @@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet): return result +def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: + """Return an entry for the login flow dict + + Returns an entry suitable for inclusion in "identity_providers" in the + response to GET /_matrix/client/r0/login + """ + e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + if idp.idp_icon: + e["icon"] = idp.idp_icon + return e + + class SsoRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they @@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet): if hs.config.oidc_enabled: hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + + def register(self, http_server: HttpServer) -> None: + super().register(http_server) + if self._msc2858_enabled: + # expose additional endpoint for MSC2858 support + http_server.register_paths( + "GET", + client_patterns( + "/org.matrix.msc2858/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$", + releases=(), + unstable=True, + ), + self.on_GET, + self.__class__.__name__, + ) - async def on_GET(self, request: SynapseRequest): + async def on_GET( + self, request: SynapseRequest, idp_id: Optional[str] = None + ) -> None: client_redirect_url = parse_string( request, "redirectUrl", required=True, encoding=None ) sso_url = await self._sso_handler.handle_redirect_request( - request, client_redirect_url + request, client_redirect_url, idp_id, ) logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2672ce24c6..e2bb945453 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -75,6 +75,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' # the query params in TEST_CLIENT_REDIRECT_URL EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -426,6 +430,57 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d["/_synapse/oidc"] = OIDCResource(self.hs) return d + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.sso"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(channel.json_body["flows"], expected_flows) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results = {} # type: Dict[str, Any] + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow 
type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(flow_results.values(), expected_flows) + def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker @@ -564,6 +619,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: prefix = key + " = " diff --git a/tests/utils.py b/tests/utils.py index 09614093bc..022223cf24 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,7 +33,6 @@ from synapse.api.room_versions import RoomVersions from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore @@ -351,7 +350,7 @@ def mock_getRawHeaders(headers=None): # This is a mock /resource/ not an entire server -class MockHttpResource(HttpServer): +class MockHttpResource: def __init__(self, prefix=""): self.callbacks = [] # 3-tuple of method/pattern/function self.prefix = prefix -- cgit 1.5.1 From 
4b73488e811714089ba447884dccb9b6ae3ac16c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2021 17:39:21 +0000 Subject: Ratelimit 3PID /requestToken API (#9238) --- changelog.d/9238.feature | 1 + docs/sample_config.yaml | 6 +- synapse/config/_base.pyi | 2 +- synapse/config/ratelimiting.py | 13 ++++- synapse/handlers/identity.py | 28 ++++++++++ synapse/rest/client/v2_alpha/account.py | 12 +++- synapse/rest/client/v2_alpha/register.py | 6 ++ tests/rest/client/v2_alpha/test_account.py | 90 ++++++++++++++++++++++++++++-- tests/server.py | 9 ++- tests/unittest.py | 5 ++ tests/utils.py | 1 + 11 files changed, 159 insertions(+), 14 deletions(-) create mode 100644 changelog.d/9238.feature (limited to 'tests') diff --git a/changelog.d/9238.feature b/changelog.d/9238.feature new file mode 100644 index 0000000000..143a3e14f5 --- /dev/null +++ b/changelog.d/9238.feature @@ -0,0 +1 @@ +Add ratelimiting to the 3PID `/requestToken` API. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c2ccd68f3a..e5b6268087 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -824,6 +824,7 @@ log_config: "CONFDIR/SERVERNAME.log.config" # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) +# - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # # The defaults are as shown below. # @@ -857,7 +858,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # remote: # per_second: 0.01 # burst_count: 3 - +# +#rc_3pid_validation: +# per_second: 0.003 +# burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 7ed07a801d..70025b5d60 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -54,7 +54,7 @@ class RootConfig: tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig - ratelimit: ratelimiting.RatelimitConfig + ratelimiting: ratelimiting.RatelimitConfig media: repository.ContentRepositoryConfig captcha: captcha.CaptchaConfig voip: voip.VoipConfig diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 14b8836197..76f382527d 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -24,7 +24,7 @@ class RateLimitConfig: defaults={"per_second": 0.17, "burst_count": 3.0}, ): self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = config.get("burst_count", defaults["burst_count"]) + self.burst_count = int(config.get("burst_count", defaults["burst_count"])) class FederationRateLimitConfig: @@ -102,6 +102,11 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 3}, ) + self.rc_3pid_validation = RateLimitConfig( + config.get("rc_3pid_validation") or {}, + defaults={"per_second": 0.003, "burst_count": 5}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -131,6 +136,7 @@ class RatelimitConfig(Config): # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) + # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # # The defaults are as shown below.
# @@ -164,7 +170,10 @@ class RatelimitConfig(Config): # remote: # per_second: 0.01 # burst_count: 3 - + # + #rc_3pid_validation: + # per_second: 0.003 + # burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index f61844d688..4f7137539b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -27,9 +27,11 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient +from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 @@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler): self._web_client_location = hs.config.invite_client_location + # Ratelimiters for `/requestToken` endpoints. + self._3pid_validation_ratelimiter_ip = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + self._3pid_validation_ratelimiter_address = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + + def ratelimit_request_token_requests( + self, request: SynapseRequest, medium: str, address: str, + ): + """Used to ratelimit requests to `/requestToken` by IP and address. + + Args: + request: The associated request + medium: The type of threepid, e.g. "msisdn" or "email" + address: The actual threepid ID, e.g. the phone number or email address + """ + + self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) + self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] ) -> Optional[JsonDict]: diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 65e68d641b..a84a2fb385 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.datastore = hs.get_datastore() @@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + # The email will be sent to the stored address. 
# This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -379,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -430,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() self.store = self.hs.get_datastore() @@ -458,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b093183e79..10e1891174 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -205,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index cb87b80e33..177dc476da 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,7 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes +from synapse.api.errors import Codes, HttpResponseException from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource @@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_email(self): + """Test that we ratelimit /requestToken for the same email. 
+ """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test1@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + def reset(ip): + client_secret = "foobar" + session_id = self._request_token(email, client_secret, ip) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + self.email_attempts.clear() + + # We expect to be able to make three requests before getting rate + # limited. + # + # We change IPs to ensure that we're not being ratelimited due to the + # same IP + reset("127.0.0.1") + reset("127.0.0.2") + reset("127.0.0.3") + + with self.assertRaises(HttpResponseException) as cm: + reset("127.0.0.4") + + self.assertEqual(cm.exception.code, 429) + def test_basic_password_reset_canonicalise_email(self): """Test basic password reset flow Request password reset with different spelling @@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) - def _request_token(self, email, client_secret): + def _request_token(self, email, client_secret, ip="127.0.0.1"): channel = self.make_request( "POST", b"account/password/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, + client_ip=ip, ) - self.assertEquals(200, channel.code, channel.result) + + if channel.code != 200: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"], + ) return channel.json_body["sid"] @@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_address_trim(self): self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_ip(self): + """Tests that adding emails is ratelimited by IP + """ + + # We expect to be able to set three emails before getting ratelimited. 
+ self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) + self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) + self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + + with self.assertRaises(HttpResponseException) as cm: + self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + + self.assertEqual(cm.exception.code, 429) + def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed """ @@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): body["next_link"] = next_link channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) - self.assertEquals(expect_code, channel.code, channel.result) + + if channel.code != expect_code: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"], + ) return channel.json_body.get("sid") @@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _add_email(self, request_email, expected_email): """Test adding an email to profile """ + previous_email_attempts = len(self.email_attempts) + client_secret = "foobar" session_id = self._request_token(request_email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) link = self._get_link_from_email() self._validate_token(link) @@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"]) + + threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} + self.assertIn(expected_email, threepids) diff --git a/tests/server.py b/tests/server.py index 5a85d5fe7f..6419c445ec 100644 --- a/tests/server.py +++ b/tests/server.py @@ -47,6 +47,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) + _ip = attr.ib(type=str, default="127.0.0.1") _producer = None @property @@ -120,7 +121,7 @@ class FakeChannel: def getPeer(self): # We give an address so that getClientIP returns a non null entry, # causing us to record the MAU - return address.IPv4Address("TCP", "127.0.0.1", 3423) + return address.IPv4Address("TCP", self._ip, 3423) def getHost(self): return None @@ -196,6 +197,7 @@ def make_request( custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, + client_ip: str = "127.0.0.1", ) -> FakeChannel: """ Make a web request using the given method, path and content, and render it @@ -223,6 +225,9 @@ def make_request( will pump the reactor until the the renderer tells the channel the request is finished. + client_ip: The IP to use as the requesting IP. Useful for testing + ratelimiting. 
+ Returns: channel """ @@ -250,7 +255,7 @@ def make_request( if isinstance(content, str): content = content.encode("utf8") - channel = FakeChannel(site, reactor) + channel = FakeChannel(site, reactor, ip=client_ip) req = request(channel) req.content = BytesIO(content) diff --git a/tests/unittest.py b/tests/unittest.py index bbd295687c..767d5d6077 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase): custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, + client_ip: str = "127.0.0.1", ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the @@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase): custom_headers: (name, value) pairs to add as request headers + client_ip: The IP to use as the requesting IP. Useful for testing + ratelimiting. + Returns: The FakeChannel object which stores the result of the request. """ @@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase): content_is_form, await_result, custom_headers, + client_ip, ) def setup_test_homeserver(self, *args, **kwargs): diff --git a/tests/utils.py b/tests/utils.py index 022223cf24..68033d7535 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -157,6 +157,7 @@ def default_config(name, parse=False): "local": {"per_second": 10000, "burst_count": 10000}, "remote": {"per_second": 10000, "burst_count": 10000}, }, + "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000}, "saml2_enabled": False, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, -- cgit 1.5.1 From f2c1560eca1e2160087a280261ca78d0708ad721 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2021 16:38:29 +0000 Subject: Ratelimit invites by room and target user (#9258) --- changelog.d/9258.feature | 1 + docs/sample_config.yaml | 10 ++++ synapse/config/ratelimiting.py | 19 +++++++ synapse/federation/federation_client.py | 2 +- synapse/handlers/federation.py | 4 ++ synapse/handlers/room.py | 7 +++ synapse/handlers/room_member.py | 25 ++++++++- tests/handlers/test_federation.py | 93 ++++++++++++++++++++++++++++++++- tests/rest/client/v1/test_rooms.py | 35 +++++++++++++ 9 files changed, 192 insertions(+), 4 deletions(-) create mode 100644 changelog.d/9258.feature (limited to 'tests') diff --git a/changelog.d/9258.feature b/changelog.d/9258.feature new file mode 100644 index 0000000000..0028f42d26 --- /dev/null +++ b/changelog.d/9258.feature @@ -0,0 +1 @@ +Add ratelimits to invites in rooms and to specific users. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 332befd948..7fd35516dc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -825,6 +825,8 @@ log_config: "CONFDIR/SERVERNAME.log.config" # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. +# - two for ratelimiting how often invites can be sent in a room or to a +# specific user. # # The defaults are as shown below. 
# @@ -862,6 +864,14 @@ log_config: "CONFDIR/SERVERNAME.log.config" #rc_3pid_validation: # per_second: 0.003 # burst_count: 5 +# +#rc_invites: +# per_room: +# per_second: 0.3 +# burst_count: 10 +# per_user: +# per_second: 0.003 +# burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 76f382527d..def33a60ad 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -107,6 +107,15 @@ class RatelimitConfig(Config): defaults={"per_second": 0.003, "burst_count": 5}, ) + self.rc_invites_per_room = RateLimitConfig( + config.get("rc_invites", {}).get("per_room", {}), + defaults={"per_second": 0.3, "burst_count": 10}, + ) + self.rc_invites_per_user = RateLimitConfig( + config.get("rc_invites", {}).get("per_user", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -137,6 +146,8 @@ class RatelimitConfig(Config): # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. + # - two for ratelimiting how often invites can be sent in a room or to a + # specific user. # # The defaults are as shown below. # @@ -174,6 +185,14 @@ class RatelimitConfig(Config): #rc_3pid_validation: # per_second: 0.003 # burst_count: 5 + # + #rc_invites: + # per_room: + # per_second: 0.3 + # burst_count: 10 + # per_user: + # per_second: 0.003 + # burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d330ae5dbc..40e1451201 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -810,7 +810,7 @@ class FederationClient(FederationBase): "User's homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) - elif e.code == 403: + elif e.code in (403, 429): raise e.to_synapse_error() else: raise diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b6dc7f99b6..dbdfd56ff5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1617,6 +1617,10 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + member_handler.ratelimit_invite(event.room_id, event.state_key) + # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). 
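The enforcement side of these new limits lives in `RoomMemberHandler.ratelimit_invite` (see the `room_member.py` hunk below). A minimal sketch of the pattern, using the default rates quoted in the sample config above — note `clock` here is an assumed stand-in for the homeserver's Clock, so this is illustrative rather than a drop-in implementation:

    from synapse.api.ratelimiting import Ratelimiter

    # One token bucket keyed on the room, one keyed on the invited user;
    # an invite must pass both before it is allowed through.
    # (`clock` is assumed to come from hs.get_clock().)
    invites_per_room = Ratelimiter(clock=clock, rate_hz=0.3, burst_count=10)
    invites_per_user = Ratelimiter(clock=clock, rate_hz=0.003, burst_count=5)

    def ratelimit_invite(room_id: str, invitee_user_id: str) -> None:
        # Each call raises LimitExceededError (surfaced to clients as a
        # 429 response) once the relevant bucket is exhausted.
        invites_per_room.ratelimit(room_id)
        invites_per_user.ratelimit(invitee_user_id)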
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ee27d99135..07b2187eb1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._invite_burst_count = ( + hs.config.ratelimiting.rc_invites_per_room.burst_count + ) + async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ) -> str: @@ -662,6 +666,9 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = [] invite_list = [] + if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count: + raise SynapseError(400, "Cannot invite so many users at once") + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e001e418f9..d335da6f19 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + self._invites_per_room_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + ) + self._invites_per_user_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -144,6 +155,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() + def ratelimit_invite(self, room_id: str, invitee_user_id: str): + """Ratelimit invites by room and by target user. 
+ """ + self._invites_per_room_limiter.ratelimit(room_id) + self._invites_per_user_limiter.ratelimit(invitee_user_id) + async def _local_membership_update( self, requester: Requester, @@ -387,8 +404,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") if effective_membership_state == Membership.INVITE: + target_id = target.to_string() + if ratelimit: + self.ratelimit_invite(room_id, target_id) + # block any attempts to invite the server notices mxid - if target.to_string() == self._server_notices_mxid: + if target_id == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -412,7 +433,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): block_invite = True if not await self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id + requester.user.to_string(), target_id, room_id ): logger.info("Blocking invite due to spam checker") block_invite = True diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 0b24b89a2e..74503112f5 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -16,7 +16,7 @@ import logging from unittest import TestCase from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json @@ -191,6 +191,97 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invite_by_room_ratelimit(self): + """Tests that invites from federation in a room are actually rate-limited. + """ + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + def create_invite_for(local_user): + return event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": other_user, + "state_key": local_user, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + for i in range(3): + self.get_success( + self.handler.on_invite_request( + other_server, + create_invite_for("@user-%d:test" % (i,)), + room_version, + ) + ) + + self.get_failure( + self.handler.on_invite_request( + other_server, create_invite_for("@user-4:test"), room_version, + ), + exc=LimitExceededError, + ) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invite_by_user_ratelimit(self): + """Tests that invites from federation to a particular user are + actually rate-limited. 
+ """ + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + def create_invite(): + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + return event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": other_user, + "state_key": "@user:test", + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + for i in range(3): + event = create_invite() + self.get_success( + self.handler.on_invite_request(other_server, event, event.room_version,) + ) + + event = create_invite() + self.get_failure( + self.handler.on_invite_request(other_server, event, event.room_version,), + exc=LimitExceededError, + ) + def _build_and_send_join_event(self, other_server, other_user, room_id): join_event = self.get_success( self.handler.on_make_join_request(other_server, room_id, other_user) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index d4e3165436..2548b3a80c 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -616,6 +616,41 @@ class RoomMemberStateTestCase(RoomBase): self.assertEquals(json.loads(content), channel.json_body) +class RoomInviteRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_rooms_ratelimit(self): + """Tests that invites in a room are actually rate-limited.""" + room_id = self.helper.create_room_as(self.user_id) + + for i in range(3): + self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) + + self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_users_ratelimit(self): + """Tests that invites to a specific user are actually rate-limited.""" + + for i in range(3): + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red") + + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) + + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" -- cgit 1.5.1 From f78d07bf005f7212bcc74256721677a3b255ea0e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 13:15:51 +0000 Subject: Split out a separate endpoint to complete SSO registration (#9262) There are going to be a couple of paths to get to the final step of SSO reg, and I want the URL in the browser to consistent. So, let's move the final step onto a separate path, which we redirect to. 
--- changelog.d/9262.feature | 1 + synapse/app/homeserver.py | 2 + synapse/handlers/sso.py | 81 ++++++++++++++++++++++------ synapse/http/server.py | 7 +++ synapse/rest/synapse/client/pick_username.py | 16 +++--- synapse/rest/synapse/client/sso_register.py | 50 +++++++++++++++++ tests/rest/client/v1/test_login.py | 14 ++++- 7 files changed, 145 insertions(+), 26 deletions(-) create mode 100644 changelog.d/9262.feature create mode 100644 synapse/rest/synapse/client/sso_register.py (limited to 'tests') diff --git a/changelog.d/9262.feature b/changelog.d/9262.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9262.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 57a2f5237c..86d6f73674 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -62,6 +62,7 @@ from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer): "/_synapse/admin": AdminRestResource(self), "/_synapse/client/pick_username": pick_username_resource(self), "/_synapse/client/pick_idp": PickIdpResource(self), + "/_synapse/client/sso_register": SsoRegisterResource(self), } ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 3308b037d2..50c5ae142a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -21,12 +21,13 @@ import attr from typing_extensions import NoReturn, Protocol from twisted.web.http import Request +from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent -from synapse.http.server import respond_with_html +from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer @@ -141,6 +142,9 @@ class UsernameMappingSession: # expiry time for the session, in milliseconds expiry_time_ms = attr.ib(type=int) + # choices made by the user + chosen_localpart = attr.ib(type=Optional[str], default=None) + # the HTTP cookie used to track the mapping session id USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" @@ -647,6 +651,25 @@ class SsoHandler: ) respond_with_html(request, 200, html) + def get_mapping_session(self, session_id: str) -> UsernameMappingSession: + """Look up the given username mapping session + + If it is not found, raises a SynapseError with an http code of 400 + + Args: + session_id: session to look up + Returns: + active mapping session + Raises: + SynapseError if the session is not found/has expired + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if session: + return session + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + async def check_username_availability( 
self, localpart: str, session_id: str, ) -> bool: @@ -663,12 +686,7 @@ class SsoHandler: # make sure that there is a valid mapping session, to stop people dictionary- # scanning for accounts - - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + self.get_mapping_session(session_id) logger.info( "[session %s] Checking for availability of username %s", @@ -696,16 +714,33 @@ class SsoHandler: localpart: localpart requested by the user session_id: ID of the username mapping session, extracted from a cookie """ - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + session = self.get_mapping_session(session_id) + + # update the session with the user's choices + session.chosen_localpart = localpart + + # we're done; now we can register the user + respond_with_redirect(request, b"/_synapse/client/sso_register") + + async def register_sso_user(self, request: Request, session_id: str) -> None: + """Called once we have all the info we need to register a new user. - logger.info("[session %s] Registering localpart %s", session_id, localpart) + Does so and serves an HTTP response + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + """ + session = self.get_mapping_session(session_id) + + logger.info( + "[session %s] Registering localpart %s", + session_id, + session.chosen_localpart, + ) attributes = UserAttributes( - localpart=localpart, + localpart=session.chosen_localpart, display_name=session.display_name, emails=session.emails, ) @@ -720,7 +755,12 @@ class SsoHandler: request.getClientIP(), ) - logger.info("[session %s] Registered userid %s", session_id, user_id) + logger.info( + "[session %s] Registered userid %s with attributes %s", + session_id, + user_id, + attributes, + ) # delete the mapping session and the cookie del self._username_mapping_sessions[session_id] @@ -751,3 +791,14 @@ class SsoHandler: for session_id in to_expire: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + + +def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: + """Extract the session ID from the cookie + + Raises a SynapseError if the cookie isn't found + """ + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + return session_id.decode("ascii", errors="replace") diff --git a/synapse/http/server.py b/synapse/http/server.py index d69d579b3a..8249732b27 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -761,6 +761,13 @@ def set_clickjacking_protection_headers(request: Request): request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") +def respond_with_redirect(request: Request, url: bytes) -> None: + """Write a 302 response to the request, if it is still alive.""" + logger.debug("Redirect to %s", url.decode("utf-8")) + request.redirect(url) + finish_request(request) + + def finish_request(request: Request): """ Finish writing the response to the request. 
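A short sketch of how an endpoint consumes the two helpers added above — the resource name here is illustrative only; the real consumers are in the following hunks:

    class ExamplePickerResource(DirectServeJsonResource):
        async def _async_render_GET(self, request):
            # raises SynapseError(400, "missing session_id") if the cookie is absent
            session_id = get_username_mapping_session_cookie_from_request(request)
            # raises SynapseError(400, "unknown session") if it has expired
            session = self._sso_handler.get_mapping_session(session_id)
            return 200, {"display_name": session.display_name}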
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index d3b6803e65..1bc737bad0 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import TYPE_CHECKING import pkg_resources @@ -20,8 +21,7 @@ from twisted.web.http import Request from twisted.web.resource import Resource from twisted.web.static import File -from synapse.api.errors import SynapseError -from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest @@ -61,12 +61,10 @@ class AvailabilityCheckResource(DirectServeJsonResource): async def _async_render_GET(self, request: Request): localpart = parse_string(request, "username", required=True) - session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) - if not session_id: - raise SynapseError(code=400, msg="missing session_id") + session_id = get_username_mapping_session_cookie_from_request(request) is_available = await self._sso_handler.check_username_availability( - localpart, session_id.decode("ascii", errors="replace") + localpart, session_id ) return 200, {"available": is_available} @@ -79,10 +77,8 @@ class SubmitResource(DirectServeHtmlResource): async def _async_render_POST(self, request: SynapseRequest): localpart = parse_string(request, "username", required=True) - session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) - if not session_id: - raise SynapseError(code=400, msg="missing session_id") + session_id = get_username_mapping_session_cookie_from_request(request) await self._sso_handler.handle_submit_username_request( - request, localpart, session_id.decode("ascii", errors="replace") + request, localpart, session_id ) diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py new file mode 100644 index 0000000000..dfefeb7796 --- /dev/null +++ b/synapse/rest/synapse/client/sso_register.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request +from synapse.http.server import DirectServeHtmlResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class SsoRegisterResource(DirectServeHtmlResource): + """A resource which completes SSO registration + + This resource gets mounted at /_synapse/client/sso_register, and is shown + after we collect username and/or consent for a new SSO user. It (finally) registers + the user, and confirms redirect to the client + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + await self._sso_handler.register_sso_user(request, session_id) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index e2bb945453..f01215ed1c 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.types import create_requester from tests import unittest @@ -1215,6 +1216,7 @@ class UsernamePickerTestCase(HomeserverTestCase): d = super().create_resource_dict() d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs) d["/_synapse/oidc"] = OIDCResource(self.hs) return d @@ -1253,7 +1255,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) # Now, submit a username to the username picker, which should serve a redirect - # back to the client + # to the completion page submit_path = picker_url + "/submit" content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( @@ -1270,6 +1272,16 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location matches the requested redirect URL path, query = location_headers[0].split("?", 1) self.assertEqual(path, "https://x") -- cgit 1.5.1 From 9c715a5f1981891815c124353ba15cf4d17bf9bb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:47:59 +0000 Subject: Fix SSO on workers (#9271) Fixes #8966. 
* Factor out build_synapse_client_resource_tree Start a function which will mount resources common to all workers. * Move sso init into build_synapse_client_resource_tree ... so that we don't have to do it for each worker * Fix SSO-login-via-a-worker Expose the SSO login endpoints on workers, like the documentation says. * Update workers config for new endpoints Add documentation for endpoints recently added (#8942, #9017, #9262) * remove submit_token from workers endpoints list this *doesn't* work on workers (yet). * changelog * Add a comment about the odd path for SAML2Resource --- changelog.d/9271.bugfix | 1 + docs/workers.md | 18 +++++----- synapse/app/generic_worker.py | 11 +++--- synapse/app/homeserver.py | 18 ++-------- synapse/rest/synapse/client/__init__.py | 49 +++++++++++++++++++++++++- synapse/storage/databases/main/registration.py | 40 ++++++++++----------- tests/rest/client/v1/test_login.py | 15 ++------ tests/rest/client/v2_alpha/test_auth.py | 6 ++-- 8 files changed, 93 insertions(+), 65 deletions(-) create mode 100644 changelog.d/9271.bugfix (limited to 'tests') diff --git a/changelog.d/9271.bugfix b/changelog.d/9271.bugfix new file mode 100644 index 0000000000..ef30c6570f --- /dev/null +++ b/changelog.d/9271.bugfix @@ -0,0 +1 @@ +Fix single-sign-on when the endpoints are routed to synapse workers. diff --git a/docs/workers.md b/docs/workers.md index d01683681f..6b8887de36 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -225,7 +225,6 @@ expressions: ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/ - ^/_synapse/client/password_reset/email/submit_token$ # Registration/login requests ^/_matrix/client/(api/v1|r0|unstable)/login$ @@ -256,25 +255,28 @@ Additionally, the following endpoints should be included if Synapse is configure to use SSO (you only need to include the ones for whichever SSO provider you're using): + # for all SSO providers + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect + ^/_synapse/client/pick_idp$ + ^/_synapse/client/pick_username + ^/_synapse/client/sso_register$ + # OpenID Connect requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ ^/_synapse/oidc/callback$ # SAML requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ ^/_matrix/saml2/authn_response$ # CAS requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$ ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ -Note that a HTTP listener with `client` and `federation` resources must be -configured in the `worker_listeners` option in the worker config. - -Ensure that all SSO logins go to a single process (usually the main process). +Ensure that all SSO logins go to a single process. For multiple workers not handling the SSO endpoints properly, see [#7530](https://github.com/matrix-org/synapse/issues/7530). +Note that a HTTP listener with `client` and `federation` resources must be +configured in the `worker_listeners` option in the worker config. 
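By way of illustration, such a listener stanza in a worker's config might look like the following minimal sketch (the port is a placeholder):

    worker_listeners:
      - type: http
        port: 8083
        resources:
          - names: [client, federation]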
+ #### Load balancing It is possible to run multiple instances of this worker app, with incoming requests diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e60988fa4a..516f2464b4 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -22,6 +22,7 @@ from typing import Dict, Iterable, Optional, Set from typing_extensions import ContextManager from twisted.internet import address +from twisted.web.resource import IResource import synapse import synapse.events @@ -90,9 +91,8 @@ from synapse.replication.tcp.streams import ( ToDeviceStream, ) from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events, room +from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet -from synapse.rest.client.v1.login import LoginRestServlet from synapse.rest.client.v1.profile import ( ProfileAvatarURLRestServlet, ProfileDisplaynameRestServlet, @@ -127,6 +127,7 @@ from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore @@ -507,7 +508,7 @@ class GenericWorkerServer(HomeServer): site_tag = port # We always include a health resource. - resources = {"/health": HealthResource()} + resources = {"/health": HealthResource()} # type: Dict[str, IResource] for res in listener_config.http_options.resources: for name in res.names: @@ -517,7 +518,7 @@ class GenericWorkerServer(HomeServer): resource = JsonResource(self, canonical_json=False) RegisterRestServlet(self).register(resource) - LoginRestServlet(self).register(resource) + login.register_servlets(self, resource) ThreepidRestServlet(self).register(resource) DevicesRestServlet(self).register(resource) KeyQueryServlet(self).register(resource) @@ -557,6 +558,8 @@ class GenericWorkerServer(HomeServer): groups.register_servlets(self, resource) resources.update({CLIENT_API_PREFIX: resource}) + + resources.update(build_synapse_client_resource_tree(self)) elif name == "federation": resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) elif name == "media": diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 86d6f73674..244657cb88 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -60,9 +60,7 @@ from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.rest.synapse.client.pick_idp import PickIdpResource -from synapse.rest.synapse.client.pick_username import pick_username_resource -from synapse.rest.synapse.client.sso_register import SsoRegisterResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -191,22 +189,10 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/versions": client_resource, "/.well-known/matrix/client": WellKnownResource(self), "/_synapse/admin": AdminRestResource(self), - 
"/_synapse/client/pick_username": pick_username_resource(self), - "/_synapse/client/pick_idp": PickIdpResource(self), - "/_synapse/client/sso_register": SsoRegisterResource(self), + **build_synapse_client_resource_tree(self), } ) - if self.get_config().oidc_enabled: - from synapse.rest.oidc import OIDCResource - - resources["/_synapse/oidc"] = OIDCResource(self) - - if self.get_config().saml2_enabled: - from synapse.rest.saml2 import SAML2Resource - - resources["/_matrix/saml2"] = SAML2Resource(self) - if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL: from synapse.rest.synapse.client.password_reset import ( PasswordResetSubmitTokenResource, diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index c0b733488b..6acbc03d73 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +12,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from typing import TYPE_CHECKING, Mapping + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]: + """Builds a resource tree to include synapse-specific client resources + + These are resources which should be loaded on all workers which expose a C-S API: + ie, the main process, and any generic workers so configured. + + Returns: + map from path to Resource. + """ + resources = { + # SSO bits. These are always loaded, whether or not SSO login is actually + # enabled (they just won't work very well if it's not) + "/_synapse/client/pick_idp": PickIdpResource(hs), + "/_synapse/client/pick_username": pick_username_resource(hs), + "/_synapse/client/sso_register": SsoRegisterResource(hs), + } + + # provider-specific SSO bits. Only load these if they are enabled, since they + # rely on optional dependencies. + if hs.config.oidc_enabled: + from synapse.rest.oidc import OIDCResource + + resources["/_synapse/oidc"] = OIDCResource(hs) + + if hs.config.saml2_enabled: + from synapse.rest.saml2 import SAML2Resource + + # This is mounted under '/_matrix' for backwards-compatibility. 
+ resources["/_matrix/saml2"] = SAML2Resource(hs) + + return resources + + +__all__ = ["build_synapse_client_resource_tree"] diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 8d05288ed4..14c0878d81 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -443,6 +443,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) + async def record_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_insert( + table="user_external_ids", + values={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="record_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: @@ -1371,26 +1391,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - async def record_user_external_id( - self, auth_provider: str, external_id: str, user_id: str - ) -> None: - """Record a mapping from an external user id to a mxid - - Args: - auth_provider: identifier for the remote auth provider - external_id: id on that system - user_id: complete mxid that it is mapped to - """ - await self.db_pool.simple_insert( - table="user_external_ids", - values={ - "auth_provider": auth_provider, - "external_id": external_id, - "user_id": user_id, - }, - desc="record_user_external_id", - ) - async def user_set_password_hash( self, user_id: str, password_hash: Optional[str] ) -> None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index f01215ed1c..ded22a9767 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -29,9 +29,7 @@ from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet -from synapse.rest.synapse.client.pick_idp import PickIdpResource -from synapse.rest.synapse.client.pick_username import pick_username_resource -from synapse.rest.synapse.client.sso_register import SsoRegisterResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import create_requester from tests import unittest @@ -424,11 +422,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) - d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d def test_get_login_flows(self): @@ -1212,12 +1207,8 @@ class UsernamePickerTestCase(HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) - d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs) - 
d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d def test_username_picker(self): diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index a6488a3d29..3f50c56745 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -22,7 +22,7 @@ from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register -from synapse.rest.oidc import OIDCResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID from tests import unittest @@ -173,9 +173,7 @@ class UIAuthTests(unittest.HomeserverTestCase): def create_resource_dict(self): resource_dict = super().create_resource_dict() - if HAS_OIDC: - # mount the OIDC resource at /_synapse/oidc - resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) + resource_dict.update(build_synapse_client_resource_tree(self.hs)) return resource_dict def prepare(self, reactor, clock, hs): -- cgit 1.5.1 From 8aed29dc615bee75019fc526a5c91cdc2638b665 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:50:56 +0000 Subject: Improve styling and wording of SSO redirect confirm template (#9272) --- changelog.d/9272.feature | 1 + docs/sample_config.yaml | 14 ++++- synapse/config/sso.py | 14 ++++- synapse/handlers/auth.py | 24 ++++++- synapse/handlers/sso.py | 10 ++- synapse/module_api/__init__.py | 10 ++- synapse/res/templates/sso.css | 83 +++++++++++++++++++++++++ synapse/res/templates/sso_redirect_confirm.html | 34 ++++++++-- tests/handlers/test_cas.py | 8 +-- tests/handlers/test_oidc.py | 24 ++++--- tests/handlers/test_saml.py | 8 +-- 11 files changed, 200 insertions(+), 30 deletions(-) create mode 100644 changelog.d/9272.feature create mode 100644 synapse/res/templates/sso.css (limited to 'tests') diff --git a/changelog.d/9272.feature b/changelog.d/9272.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9272.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 8777e3254d..05506a7787 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1971,7 +1971,8 @@ sso: # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # - # When rendering, this template is given three variables: + # When rendering, this template is given the following variables: + # # * redirect_url: the URL the user is about to be redirected to. Needs # manual escaping (see # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). @@ -1984,6 +1985,17 @@ sso: # # * server_name: the homeserver's name. # + # * new_user: a boolean indicating whether this is the user's first time + # logging in. + # + # * user_id: the user's matrix ID. + # + # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. + # None if the user has not set an avatar. + # + # * user_profile.display_name: the user's display name. None if the user + # has not set a display name. 
+ # # * HTML page which notifies the user that they are authenticating to confirm # an operation on their account during the user interactive authentication # process: 'sso_auth_confirm.html'. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 59be825532..a470112ed4 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -127,7 +127,8 @@ class SSOConfig(Config): # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # - # When rendering, this template is given three variables: + # When rendering, this template is given the following variables: + # # * redirect_url: the URL the user is about to be redirected to. Needs # manual escaping (see # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). @@ -140,6 +141,17 @@ class SSOConfig(Config): # # * server_name: the homeserver's name. # + # * new_user: a boolean indicating whether this is the user's first time + # logging in. + # + # * user_id: the user's matrix ID. + # + # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. + # None if the user has not set an avatar. + # + # * user_profile.display_name: the user's display name. None if the user + # has not set a display name. + # # * HTML page which notifies the user that they are authenticating to confirm # an operation on their account during the user interactive authentication # process: 'sso_auth_confirm.html'. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0e98db22b3..c722a4afa8 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -61,6 +61,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi +from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import maybe_awaitable @@ -1396,6 +1397,7 @@ class AuthHandler(BaseHandler): request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, + new_user: bool = False, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1406,6 +1408,8 @@ class AuthHandler(BaseHandler): process. extra_attributes: Extra attributes which will be passed to the client during successful login. Must be JSON serializable. + new_user: True if we should use wording appropriate to a user who has just + registered. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1414,8 +1418,17 @@ class AuthHandler(BaseHandler): respond_with_html(request, 403, self._sso_account_deactivated_template) return + profile = await self.store.get_profileinfo( + UserID.from_string(registered_user_id).localpart + ) + self._complete_sso_login( - registered_user_id, request, client_redirect_url, extra_attributes + registered_user_id, + request, + client_redirect_url, + extra_attributes, + new_user=new_user, + user_profile_data=profile, ) def _complete_sso_login( @@ -1424,12 +1437,18 @@ class AuthHandler(BaseHandler): request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, + new_user: bool = False, + user_profile_data: Optional[ProfileInfo] = None, ): """ The synchronous portion of complete_sso_login. This exists purely for backwards compatibility of synapse.module_api.ModuleApi. 
""" + + if user_profile_data is None: + user_profile_data = ProfileInfo(None, None) + # Store any extra attributes which will be passed in the login response. # Note that this is per-user so it may overwrite a previous value, this # is considered OK since the newest SSO attributes should be most valid. @@ -1467,6 +1486,9 @@ class AuthHandler(BaseHandler): display_url=redirect_url_no_params, redirect_url=redirect_url, server_name=self._server_name, + new_user=new_user, + user_id=registered_user_id, + user_profile=user_profile_data, ) respond_with_html(request, 200, html) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 50c5ae142a..ceaeb5a376 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -391,6 +391,8 @@ class SsoHandler: to an additional page. (e.g. to prompt for more information) """ + new_user = False + # grab a lock while we try to find a mapping for this user. This seems... # optimistic, especially for implementations that end up redirecting to # interstitial pages. @@ -431,9 +433,14 @@ class SsoHandler: get_request_user_agent(request), request.getClientIP(), ) + new_user = True await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url, extra_login_attributes + user_id, + request, + client_redirect_url, + extra_login_attributes, + new_user=new_user, ) async def _call_attribute_mapper( @@ -778,6 +785,7 @@ class SsoHandler: request, session.client_redirect_url, session.extra_login_attributes, + new_user=True, ) def _expire_old_sessions(self): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 72ab5750cc..401d577293 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -279,7 +279,11 @@ class ModuleApi: ) async def complete_sso_login_async( - self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str + self, + registered_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + new_user: bool = False, ): """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that @@ -291,9 +295,11 @@ class ModuleApi: request: The request to respond to. client_redirect_url: The URL to which to offer to redirect the user (or to redirect them directly if whitelisted). + new_user: set to true to use wording for the consent appropriate to a user + who has just registered. 
""" await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url, + registered_user_id, request, client_redirect_url, new_user=new_user ) @defer.inlineCallbacks diff --git a/synapse/res/templates/sso.css b/synapse/res/templates/sso.css new file mode 100644 index 0000000000..ff9dc94032 --- /dev/null +++ b/synapse/res/templates/sso.css @@ -0,0 +1,83 @@ +body { + font-family: "Inter", "Helvetica", "Arial", sans-serif; + font-size: 14px; + color: #17191C; +} + +header { + max-width: 480px; + width: 100%; + margin: 24px auto; + text-align: center; +} + +header p { + color: #737D8C; + line-height: 24px; +} + +h1 { + font-size: 24px; +} + +h2 { + font-size: 14px; +} + +h2 img { + vertical-align: middle; + margin-right: 8px; + width: 24px; + height: 24px; +} + +label { + cursor: pointer; +} + +main { + max-width: 360px; + width: 100%; + margin: 24px auto; +} + +.primary-button { + border: none; + text-decoration: none; + padding: 12px; + color: white; + background-color: #418DED; + font-weight: bold; + display: block; + border-radius: 12px; + width: 100%; + margin: 16px 0; + cursor: pointer; + text-align: center; +} + +.profile { + display: flex; + justify-content: center; + margin: 24px 0; +} + +.profile .avatar { + width: 36px; + height: 36px; + border-radius: 100%; + display: block; + margin-right: 8px; +} + +.profile .display-name { + font-weight: bold; + margin-bottom: 4px; +} +.profile .user-id { + color: #737D8C; +} + +.profile .display-name, .profile .user-id { + line-height: 18px; +} \ No newline at end of file diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html index 20a15e1e74..ce4f573848 100644 --- a/synapse/res/templates/sso_redirect_confirm.html +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -3,12 +3,34 @@ SSO redirect confirmation + + -

 </head>
     <body>
-        <p>The application at <a href="{{ display_url | e }}">{{ display_url | e }}</a> is requesting full access to your {{ server_name }} Matrix account.</p>
-        <p>If you don't recognise this address, you should ignore this and close this tab.</p>
-        <p>
-            <a href="{{ redirect_url | e }}">I trust this address</a>
-        </p>
+        <header>
+            {% if new_user %}
+            <h1>Your account is now ready</h1>
+            <p>You've made your account on {{ server_name | e }}.</p>
+            {% else %}
+            <h1>Log in</h1>
+            {% endif %}
+            <p>Continue to confirm you trust {{ display_url | e }}.</p>
+        </header>
+        <main>
+            {% if user_profile.avatar_url %}
+            <section class="profile">
+                <img class="avatar" src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" />
+                <div class="profile-details">
+                    {% if user_profile.display_name %}
+                    <div class="display-name">{{ user_profile.display_name | e }}</div>
+                    {% endif %}
+                    <div class="user-id">{{ user_id | e }}</div>
+                </div>
+            </section>
+            {% endif %}
+            <a href="{{ redirect_url | e }}" class="primary-button">Continue</a>
+        </main>
     </body>
    - \ No newline at end of file + diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index c37bb6440e..7baf224f7e 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -62,7 +62,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=True ) def test_map_cas_user_to_existing_user(self): @@ -85,7 +85,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -94,7 +94,7 @@ class CasHandlerTestCase(HomeserverTestCase): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=False ) def test_map_cas_user_to_invalid_localpart(self): @@ -112,7 +112,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", request, "redirect_uri", None + "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index b3dfa40d25..d8f90b9a80 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -419,7 +419,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, + expected_user_id, request, client_redirect_url, None, new_user=True ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -450,7 +450,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, + expected_user_id, request, client_redirect_url, None, new_user=False ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() @@ -623,7 +623,11 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - "@foo:test", request, client_redirect_url, {"phone": "1234567"}, + "@foo:test", + request, + client_redirect_url, + {"phone": "1234567"}, + new_user=True, ) def test_map_userinfo_to_user(self): @@ -637,7 +641,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", ANY, ANY, None, + "@test_user:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -648,7 +652,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", ANY, ANY, None, + 
"@test_user_2:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -685,14 +689,14 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -707,7 +711,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -743,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", ANY, ANY, None, + "@TEST_USER_2:test", ANY, ANY, None, new_user=False ) def test_map_userinfo_to_invalid_localpart(self): @@ -779,7 +783,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", ANY, ANY, None, + "@test_user1:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 261c7083d1..a8d6c0f617 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=True ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "", None + "@test_user:test", request, "", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "", None + "@test_user:test", request, "", None, new_user=False ) def test_map_saml_response_to_invalid_localpart(self): @@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. 
auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", request, "", None + "@test_user1:test", request, "", None, new_user=True ) auth_handler.complete_sso_login.reset_mock() -- cgit 1.5.1 From 4167494c90bc0477bdf4855a79e81dc81bba1377 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:52:50 +0000 Subject: Replace username picker with a template (#9275) There's some prelimiary work here to pull out the construction of a jinja environment to a separate function. I wanted to load the template at display time rather than load time, so that it's easy to update on the fly. Honestly, I think we should do this with all our templates: the risk of ending up with malformed templates is far outweighed by the improved turnaround time for an admin trying to update them. --- changelog.d/9275.feature | 1 + docs/sample_config.yaml | 32 +++++- synapse/config/_base.py | 39 +------ synapse/config/oidc_config.py | 3 +- synapse/config/sso.py | 33 +++++- synapse/handlers/sso.py | 2 +- .../res/templates/sso_auth_account_details.html | 115 +++++++++++++++++++++ synapse/res/templates/sso_auth_account_details.js | 76 ++++++++++++++ synapse/res/username_picker/index.html | 19 ---- synapse/res/username_picker/script.js | 95 ----------------- synapse/res/username_picker/style.css | 27 ----- synapse/rest/consent/consent_resource.py | 1 + synapse/rest/synapse/client/pick_username.py | 79 ++++++++++---- synapse/util/templates.py | 106 +++++++++++++++++++ tests/rest/client/v1/test_login.py | 5 +- 15 files changed, 429 insertions(+), 204 deletions(-) create mode 100644 changelog.d/9275.feature create mode 100644 synapse/res/templates/sso_auth_account_details.html create mode 100644 synapse/res/templates/sso_auth_account_details.js delete mode 100644 synapse/res/username_picker/index.html delete mode 100644 synapse/res/username_picker/script.js delete mode 100644 synapse/res/username_picker/style.css create mode 100644 synapse/util/templates.py (limited to 'tests') diff --git a/changelog.d/9275.feature b/changelog.d/9275.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9275.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 05506a7787..a6fbcc6080 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1801,7 +1801,8 @@ saml2_config: # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their -# own username. +# own username (see 'sso_auth_account_details.html' in the 'sso' +# section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. @@ -1968,6 +1969,35 @@ sso: # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. 
+ # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 94144efc87..35e5594b73 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,18 +18,18 @@ import argparse import errno import os -import time -import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Callable, Iterable, List, MutableMapping, Optional +from typing import Any, Iterable, List, MutableMapping, Optional import attr import jinja2 import pkg_resources import yaml +from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter + class ConfigError(Exception): """Represents a problem parsing the configuration @@ -248,6 +248,7 @@ class Config: # Search the custom template directory as well search_directories.insert(0, custom_template_directory) + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) env = jinja2.Environment(loader=loader, autoescape=autoescape) @@ -267,38 +268,6 @@ class Config: return templates -def _format_ts_filter(value: int, format: str): - return time.strftime(format, time.localtime(value / 1000)) - - -def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: - """Create and return a jinja2 filter that converts MXC urls to HTTP - - Args: - public_baseurl: The public, accessible base URL of the homeserver - """ - - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - server_and_media_id = value[6:] - fragment = None - if "#" in server_and_media_id: - server_and_media_id, fragment = server_and_media_id.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - server_and_media_id, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter - - class RootConfig: """ Holder of an application's configuration. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index f31511e039..784b416f95 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -152,7 +152,8 @@ class OIDCConfig(Config): # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their - # own username. + # own username (see 'sso_auth_account_details.html' in the 'sso' + # section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. 
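As an illustration of the `localpart_template` option documented above, the following minimal sketch (not part of this patch) shows how such a Jinja2 template is evaluated. Rendering with a `user` variable follows the documented convention for the default mapping provider; the userinfo values are invented for the example:

    import jinja2

    # Minimal sketch: applying a localpart_template to provider userinfo.
    # The userinfo dict below is invented for illustration.
    env = jinja2.Environment()
    localpart_template = env.from_string("{{ user.preferred_username }}")

    userinfo = {"preferred_username": "alice", "name": "Alice Example"}
    localpart = localpart_template.render(user=userinfo)
    assert localpart == "alice"  # would yield @alice:<server_name>

When the option is not set at all, the user is instead prompted to pick a username via the 'sso_auth_account_details.html' page described above.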
diff --git a/synapse/config/sso.py b/synapse/config/sso.py index a470112ed4..e308fc9333 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -27,7 +27,7 @@ class SSOConfig(Config): sso_config = config.get("sso") or {} # type: Dict[str, Any] # The sso-specific template_dir - template_dir = sso_config.get("template_dir") + self.sso_template_dir = sso_config.get("template_dir") # Read templates from disk ( @@ -48,7 +48,7 @@ class SSOConfig(Config): "sso_auth_success.html", "sso_auth_bad_user.html", ], - template_dir, + self.sso_template_dir, ) # These templates have no placeholders, so render them here @@ -124,6 +124,35 @@ class SSOConfig(Config): # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index ceaeb5a376..ff4750999a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -530,7 +530,7 @@ class SsoHandler: logger.info("Recorded registration session id %s", session_id) # Set the cookie and redirect to the username picker - e = RedirectException(b"/_synapse/client/pick_username") + e = RedirectException(b"/_synapse/client/pick_username/account_details") e.cookies.append( b"%s=%s; path=/" % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html new file mode 100644 index 0000000000..f22b09aec1 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.html @@ -0,0 +1,115 @@ + + + + Synapse Login + + + + + +
+  <body>
+    <header>
+      <h1>Your account is nearly ready</h1>
+      <p>Check your details before creating an account on {{ server_name }}</p>
+    </header>
+    <main>
+      <form method="post" class="form__input" id="form">
+        <div class="username_input" id="username_input">
+          <label for="field-username">Username</label>
+          <div class="prefix">@</div>
+          <input type="text" name="username" id="field-username" autofocus required pattern="[a-z0-9\-=_\/\.]+">
+          <div class="postfix">:{{ server_name }}</div>
+        </div>
+        <input type="submit" value="Continue" class="primary-button">
+        {% if user_attributes %}
+        <section class="idp-pick-details">
+          <h2>Information from {{ idp.idp_name }}</h2>
+          {% if user_attributes.avatar_url %}
+          <div class="idp-detail idp-avatar">
+            <img src="{{ user_attributes.avatar_url }}" class="avatar" />
+          </div>
+          {% endif %}
+          {% if user_attributes.display_name %}
+          <div class="idp-detail">
+            <p>{{ user_attributes.display_name }}</p>
+          </div>
+          {% endif %}
+          {% for email in user_attributes.emails %}
+          <div class="idp-detail">
+            <p>{{ email }}</p>
+          </div>
+          {% endfor %}
+        </section>
+        {% endif %}
+      </form>
+    </main>
    + + + diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js new file mode 100644 index 0000000000..deef419bb6 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.js @@ -0,0 +1,76 @@ +const usernameField = document.getElementById("field-username"); + +function throttle(fn, wait) { + let timeout; + return function() { + const args = Array.from(arguments); + if (timeout) { + clearTimeout(timeout); + } + timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait); + } +} + +function checkUsernameAvailable(username) { + let check_uri = 'check?username=' + encodeURIComponent(username); + return fetch(check_uri, { + // include the cookie + "credentials": "same-origin", + }).then((response) => { + if(!response.ok) { + // for non-200 responses, raise the body of the response as an exception + return response.text().then((text) => { throw new Error(text); }); + } else { + return response.json(); + } + }).then((json) => { + if(json.error) { + return {message: json.error}; + } else if(json.available) { + return {available: true}; + } else { + return {message: username + " is not available, please choose another."}; + } + }); +} + +function validateUsername(username) { + usernameField.setCustomValidity(""); + if (usernameField.validity.valueMissing) { + usernameField.setCustomValidity("Please provide a username"); + return; + } + if (usernameField.validity.patternMismatch) { + usernameField.setCustomValidity("Invalid username, please only use " + allowedCharactersString); + return; + } + usernameField.setCustomValidity("Checking if username is available …"); + throttledCheckUsernameAvailable(username); +} + +const throttledCheckUsernameAvailable = throttle(function(username) { + const handleError = function(err) { + // don't prevent form submission on error + usernameField.setCustomValidity(""); + console.log(err.message); + }; + try { + checkUsernameAvailable(username).then(function(result) { + if (!result.available) { + usernameField.setCustomValidity(result.message); + usernameField.reportValidity(); + } else { + usernameField.setCustomValidity(""); + } + }, handleError); + } catch (err) { + handleError(err); + } +}, 500); + +usernameField.addEventListener("input", function(evt) { + validateUsername(usernameField.value); +}); +usernameField.addEventListener("change", function(evt) { + validateUsername(usernameField.value); +}); diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html deleted file mode 100644 index 37ea8bb6d8..0000000000 --- a/synapse/res/username_picker/index.html +++ /dev/null @@ -1,19 +0,0 @@ - - - - Synapse Login - - - -
-    <div class="card">
-        <form method="post" class="form__input" id="form" action="submit">
-            <label for="field-username">Please pick your username:</label>
-            <input type="text" name="username" id="field-username" autofocus="">
-            <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
-        </form>
-        <div role="alert" class="tooltip hidden" id="message"></div>
-        <script type="text/javascript" src="script.js"></script>
-    </div>
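For context before the deletion of script.js below: both the old picker script and its replacement, sso_auth_account_details.js, ask the server whether a username is free via the 'check' endpoint, which responds with JSON such as {"available": true}. The same request, sketched as a standalone Python client (hypothetical, not part of this patch; it assumes the username-mapping session cookie issued during the SSO flow):

    import json
    import urllib.parse
    import urllib.request

    def check_username_available(base_url: str, username: str, session_cookie: str) -> bool:
        # Hypothetical client for the availability check; not part of this patch.
        query = urllib.parse.urlencode({"username": username})
        request = urllib.request.Request(
            base_url + "/_synapse/client/pick_username/check?" + query,
            headers={"Cookie": session_cookie},
        )
        with urllib.request.urlopen(request) as response:
            body = json.load(response)
        # The endpoint responds with {"available": true|false} or {"error": ...}.
        return bool(body.get("available"))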
    - - diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js deleted file mode 100644 index 416a7c6f41..0000000000 --- a/synapse/res/username_picker/script.js +++ /dev/null @@ -1,95 +0,0 @@ -let inputField = document.getElementById("field-username"); -let inputForm = document.getElementById("form"); -let submitButton = document.getElementById("button-submit"); -let message = document.getElementById("message"); - -// Submit username and receive response -function showMessage(messageText) { - // Unhide the message text - message.classList.remove("hidden"); - - message.textContent = messageText; -}; - -function doSubmit() { - showMessage("Success. Please wait a moment for your browser to redirect."); - - // remove the event handler before re-submitting the form. - delete inputForm.onsubmit; - inputForm.submit(); -} - -function onResponse(response) { - // Display message - showMessage(response); - - // Enable submit button and input field - submitButton.classList.remove('button--disabled'); - submitButton.value = "Submit"; -}; - -let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]"); -function usernameIsValid(username) { - return !allowedUsernameCharacters.test(username); -} -let allowedCharactersString = "lowercase letters, digits, ., _, -, /, ="; - -function buildQueryString(params) { - return Object.keys(params) - .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k])) - .join('&'); -} - -function submitUsername(username) { - if(username.length == 0) { - onResponse("Please enter a username."); - return; - } - if(!usernameIsValid(username)) { - onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString); - return; - } - - // if this browser doesn't support fetch, skip the availability check. - if(!window.fetch) { - doSubmit(); - return; - } - - let check_uri = 'check?' 
+ buildQueryString({"username": username}); - fetch(check_uri, { - // include the cookie - "credentials": "same-origin", - }).then((response) => { - if(!response.ok) { - // for non-200 responses, raise the body of the response as an exception - return response.text().then((text) => { throw text; }); - } else { - return response.json(); - } - }).then((json) => { - if(json.error) { - throw json.error; - } else if(json.available) { - doSubmit(); - } else { - onResponse("This username is not available, please choose another."); - } - }).catch((err) => { - onResponse("Error checking username availability: " + err); - }); -} - -function clickSubmit() { - event.preventDefault(); - if(submitButton.classList.contains('button--disabled')) { return; } - - // Disable submit button and input field - submitButton.classList.add('button--disabled'); - - // Submit username - submitButton.value = "Checking..."; - submitUsername(inputField.value); -}; - -inputForm.onsubmit = clickSubmit; diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css deleted file mode 100644 index 745bd4c684..0000000000 --- a/synapse/res/username_picker/style.css +++ /dev/null @@ -1,27 +0,0 @@ -input[type="text"] { - font-size: 100%; - background-color: #ededf0; - border: 1px solid #fff; - border-radius: .2em; - padding: .5em .9em; - display: block; - width: 26em; -} - -.button--disabled { - border-color: #fff; - background-color: transparent; - color: #000; - text-transform: none; -} - -.hidden { - display: none; -} - -.tooltip { - background-color: #f9f9fa; - padding: 1em; - margin: 1em 0; -} - diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index b3e4d5612e..8b9ef26cf2 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource): consent_template_directory = hs.config.user_consent_template_dir + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(consent_template_directory) self._jinja_env = jinja2.Environment( loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 1bc737bad0..27540d3bbe 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -13,41 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import TYPE_CHECKING -import pkg_resources - from twisted.web.http import Request from twisted.web.resource import Resource -from twisted.web.static import File +from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request -from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest +from synapse.util.templates import build_jinja_env if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + def pick_username_resource(hs: "HomeServer") -> Resource: """Factory method to generate the username picker resource. - This resource gets mounted under /_synapse/client/pick_username. 
The top-level - resource is just a File resource which serves up the static files in the resources - "res" directory, but it has a couple of children: + This resource gets mounted under /_synapse/client/pick_username and has two + children: - * "submit", which does the mechanics of registering the new user, and redirects the - browser back to the client URL - - * "check": checks if a userid is free. + * "account_details": renders the form and handles the POSTed response + * "check": a JSON endpoint which checks if a userid is free. """ - # XXX should we make this path customisable so that admins can restyle it? - base_path = pkg_resources.resource_filename("synapse", "res/username_picker") - - res = File(base_path) - res.putChild(b"submit", SubmitResource(hs)) + res = Resource() + res.putChild(b"account_details", AccountDetailsResource(hs)) res.putChild(b"check", AvailabilityCheckResource(hs)) return res @@ -69,15 +69,54 @@ class AvailabilityCheckResource(DirectServeJsonResource): return 200, {"available": is_available} -class SubmitResource(DirectServeHtmlResource): +class AccountDetailsResource(DirectServeHtmlResource): def __init__(self, hs: "HomeServer"): super().__init__() self._sso_handler = hs.get_sso_handler() - async def _async_render_POST(self, request: SynapseRequest): - localpart = parse_string(request, "username", required=True) + def template_search_dirs(): + if hs.config.sso.sso_template_dir: + yield hs.config.sso.sso_template_dir + yield hs.config.sso.default_template_dir + + self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + session = self._sso_handler.get_mapping_session(session_id) + except SynapseError as e: + logger.warning("Error fetching session: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + idp_id = session.auth_provider_id + template_params = { + "idp": self._sso_handler.get_identity_providers()[idp_id], + "user_attributes": { + "display_name": session.display_name, + "emails": session.emails, + }, + } + + template = self._jinja_env.get_template("sso_auth_account_details.html") + html = template.render(template_params) + respond_with_html(request, 200, html) - session_id = get_username_mapping_session_cookie_from_request(request) + async def _async_render_POST(self, request: SynapseRequest): + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + try: + localpart = parse_string(request, "username", required=True) + except SynapseError as e: + logger.warning("[session %s] bad param: %s", session_id, e) + self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) + return await self._sso_handler.handle_submit_username_request( request, localpart, session_id diff --git a/synapse/util/templates.py b/synapse/util/templates.py new file mode 100644 index 0000000000..7e5109d206 --- /dev/null +++ b/synapse/util/templates.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for dealing with jinja2 templates""" + +import time +import urllib.parse +from typing import TYPE_CHECKING, Callable, Iterable, Union + +import jinja2 + +if TYPE_CHECKING: + from synapse.config.homeserver import HomeServerConfig + + +def build_jinja_env( + template_search_directories: Iterable[str], + config: "HomeServerConfig", + autoescape: Union[bool, Callable[[str], bool], None] = None, +) -> jinja2.Environment: + """Set up a Jinja2 environment to load templates from the given search path + + The returned environment defines the following filters: + - format_ts: formats timestamps as strings in the server's local timezone + (XXX: why is that useful??) + - mxc_to_http: converts mxc: uris to http URIs. Args are: + (uri, width, height, resize_method="crop") + + and the following global variables: + - server_name: matrix server name + + Args: + template_search_directories: directories to search for templates + + config: homeserver config, for things like `server_name` and `public_baseurl` + + autoescape: whether template variables should be autoescaped. bool, or + a function mapping from template name to bool. Defaults to escaping templates + whose names end in .html, .xml or .htm. + + Returns: + jinja environment + """ + + if autoescape is None: + autoescape = jinja2.select_autoescape() + + loader = jinja2.FileSystemLoader(template_search_directories) + env = jinja2.Environment(loader=loader, autoescape=autoescape) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl), + } + ) + + # common variables for all templates + env.globals.update({"server_name": config.server_name}) + + return env + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index ded22a9767..66dfdaffbc 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1222,7 +1222,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # that should redirect to the username picker self.assertEqual(channel.code, 302, channel.result) picker_url = channel.headers.getRawHeaders("Location")[0] - self.assertEqual(picker_url, 
"/_synapse/client/pick_username") + self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie cookies = {} # type: Dict[str,str] @@ -1247,11 +1247,10 @@ class UsernamePickerTestCase(HomeserverTestCase): # Now, submit a username to the username picker, which should serve a redirect # to the completion page - submit_path = picker_url + "/submit" content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( "POST", - path=submit_path, + path=picker_url, content=content, content_is_form=True, custom_headers=[ -- cgit 1.5.1 From 5d38a3c97f3de4963584fabbd6dc82d55129ecdf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 1 Feb 2021 13:09:39 -0500 Subject: Refactor email summary generation. (#9260) * Fixes a case where no summary text was returned. * The use of messages_from_person vs. messages_from_person_and_others was tweaked to depend on whether there was 1 sender or multiple senders, not based on if there was 1 room or multiple rooms. --- changelog.d/9260.misc | 1 + synapse/push/mailer.py | 295 +++++++++++++++++++++++++++-------------------- tests/push/test_email.py | 30 +++++ 3 files changed, 204 insertions(+), 122 deletions(-) create mode 100644 changelog.d/9260.misc (limited to 'tests') diff --git a/changelog.d/9260.misc b/changelog.d/9260.misc new file mode 100644 index 0000000000..0150e10ea9 --- /dev/null +++ b/changelog.d/9260.misc @@ -0,0 +1 @@ +Refactor the generation of summary text for email notifications. diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 745b1dde94..8a6dcff30d 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -267,9 +267,21 @@ class Mailer: fallback_to_members=True, ) - summary_text = await self.make_summary_text( - notifs_by_room, state_by_room, notif_events, user_id, reason - ) + if len(notifs_by_room) == 1: + # Only one room has new stuff + room_id = list(notifs_by_room.keys())[0] + + summary_text = await self.make_summary_text_single_room( + room_id, + notifs_by_room[room_id], + state_by_room[room_id], + notif_events, + user_id, + ) + else: + summary_text = await self.make_summary_text( + notifs_by_room, state_by_room, notif_events, reason + ) template_vars = { "user_display_name": user_display_name, @@ -492,139 +504,178 @@ class Mailer: if "url" in event.content: messagevars["image_url"] = event.content["url"] - async def make_summary_text( + async def make_summary_text_single_room( self, - notifs_by_room: Dict[str, List[Dict[str, Any]]], - room_state_ids: Dict[str, StateMap[str]], + room_id: str, + notifs: List[Dict[str, Any]], + room_state_ids: StateMap[str], notif_events: Dict[str, EventBase], user_id: str, - reason: Dict[str, Any], - ): - if len(notifs_by_room) == 1: - # Only one room has new stuff - room_id = list(notifs_by_room.keys())[0] + ) -> str: + """ + Make a summary text for the email when only a single room has notifications. - # If the room has some kind of name, use it, but we don't - # want the generated-from-names one here otherwise we'll - # end up with, "new message from Bob in the Bob room" - room_name = await calculate_room_name( - self.store, room_state_ids[room_id], user_id, fallback_to_members=False - ) + Args: + room_id: The ID of the room. + notifs: The notifications for this room. + room_state_ids: The state map for the room. + notif_events: A map of event ID -> notification event. + user_id: The user receiving the notification. + + Returns: + The summary text. 
+ """ + # If the room has some kind of name, use it, but we don't + # want the generated-from-names one here otherwise we'll + # end up with, "new message from Bob in the Bob room" + room_name = await calculate_room_name( + self.store, room_state_ids, user_id, fallback_to_members=False + ) - # See if one of the notifs is an invite event for the user - invite_event = None - for n in notifs_by_room[room_id]: - ev = notif_events[n["event_id"]] - if ev.type == EventTypes.Member and ev.state_key == user_id: - if ev.content.get("membership") == Membership.INVITE: - invite_event = ev - break - - if invite_event: - inviter_member_event_id = room_state_ids[room_id].get( - ("m.room.member", invite_event.sender) - ) - inviter_name = invite_event.sender - if inviter_member_event_id: - inviter_member_event = await self.store.get_event( - inviter_member_event_id, allow_none=True - ) - if inviter_member_event: - inviter_name = name_from_member_event(inviter_member_event) - - if room_name is None: - return self.email_subjects.invite_from_person % { - "person": inviter_name, - "app": self.app_name, - } - else: - return self.email_subjects.invite_from_person_to_room % { - "person": inviter_name, - "room": room_name, - "app": self.app_name, - } + # See if one of the notifs is an invite event for the user + invite_event = None + for n in notifs: + ev = notif_events[n["event_id"]] + if ev.type == EventTypes.Member and ev.state_key == user_id: + if ev.content.get("membership") == Membership.INVITE: + invite_event = ev + break - sender_name = None - if len(notifs_by_room[room_id]) == 1: - # There is just the one notification, so give some detail - event = notif_events[notifs_by_room[room_id][0]["event_id"]] - if ("m.room.member", event.sender) in room_state_ids[room_id]: - state_event_id = room_state_ids[room_id][ - ("m.room.member", event.sender) - ] - state_event = await self.store.get_event(state_event_id) - sender_name = name_from_member_event(state_event) - - if sender_name is not None and room_name is not None: - return self.email_subjects.message_from_person_in_room % { - "person": sender_name, - "room": room_name, - "app": self.app_name, - } - elif sender_name is not None: - return self.email_subjects.message_from_person % { - "person": sender_name, - "app": self.app_name, - } - else: - # There's more than one notification for this room, so just - # say there are several - if room_name is not None: - return self.email_subjects.messages_in_room % { - "room": room_name, - "app": self.app_name, - } - else: - # If the room doesn't have a name, say who the messages - # are from explicitly to avoid, "messages in the Bob room" - sender_ids = list( - { - notif_events[n["event_id"]].sender - for n in notifs_by_room[room_id] - } - ) - - member_events = await self.store.get_events( - [ - room_state_ids[room_id][("m.room.member", s)] - for s in sender_ids - ] - ) - - return self.email_subjects.messages_from_person % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - } - else: - # Stuff's happened in multiple different rooms + if invite_event: + inviter_member_event_id = room_state_ids.get( + ("m.room.member", invite_event.sender) + ) + inviter_name = invite_event.sender + if inviter_member_event_id: + inviter_member_event = await self.store.get_event( + inviter_member_event_id, allow_none=True + ) + if inviter_member_event: + inviter_name = name_from_member_event(inviter_member_event) - # ...but we still refer to the 'reason' room which triggered the mail - if 
reason["room_name"] is not None: - return self.email_subjects.messages_in_room_and_others % { - "room": reason["room_name"], + if room_name is None: + return self.email_subjects.invite_from_person % { + "person": inviter_name, "app": self.app_name, } - else: - # If the reason room doesn't have a name, say who the messages - # are from explicitly to avoid, "messages in the Bob room" - room_id = reason["room_id"] - - sender_ids = list( - { - notif_events[n["event_id"]].sender - for n in notifs_by_room[room_id] - } - ) - member_events = await self.store.get_events( - [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] - ) + return self.email_subjects.invite_from_person_to_room % { + "person": inviter_name, + "room": room_name, + "app": self.app_name, + } + + if len(notifs) == 1: + # There is just the one notification, so give some detail + sender_name = None + event = notif_events[notifs[0]["event_id"]] + if ("m.room.member", event.sender) in room_state_ids: + state_event_id = room_state_ids[("m.room.member", event.sender)] + state_event = await self.store.get_event(state_event_id) + sender_name = name_from_member_event(state_event) + + if sender_name is not None and room_name is not None: + return self.email_subjects.message_from_person_in_room % { + "person": sender_name, + "room": room_name, + "app": self.app_name, + } + elif sender_name is not None: + return self.email_subjects.message_from_person % { + "person": sender_name, + "app": self.app_name, + } - return self.email_subjects.messages_from_person_and_others % { - "person": descriptor_from_member_events(member_events.values()), + # The sender is unknown, just use the room name (or ID). + return self.email_subjects.messages_in_room % { + "room": room_name or room_id, + "app": self.app_name, + } + else: + # There's more than one notification for this room, so just + # say there are several + if room_name is not None: + return self.email_subjects.messages_in_room % { + "room": room_name, "app": self.app_name, } + return await self.make_summary_text_from_member_events( + room_id, notifs, room_state_ids, notif_events + ) + + async def make_summary_text( + self, + notifs_by_room: Dict[str, List[Dict[str, Any]]], + room_state_ids: Dict[str, StateMap[str]], + notif_events: Dict[str, EventBase], + reason: Dict[str, Any], + ) -> str: + """ + Make a summary text for the email when multiple rooms have notifications. + + Args: + notifs_by_room: A map of room ID to the notifications for that room. + room_state_ids: A map of room ID to the state map for that room. + notif_events: A map of event ID -> notification event. + reason: The reason this notification is being sent. + + Returns: + The summary text. + """ + # Stuff's happened in multiple different rooms + # ...but we still refer to the 'reason' room which triggered the mail + if reason["room_name"] is not None: + return self.email_subjects.messages_in_room_and_others % { + "room": reason["room_name"], + "app": self.app_name, + } + + room_id = reason["room_id"] + return await self.make_summary_text_from_member_events( + room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events + ) + + async def make_summary_text_from_member_events( + self, + room_id: str, + notifs: List[Dict[str, Any]], + room_state_ids: StateMap[str], + notif_events: Dict[str, EventBase], + ) -> str: + """ + Make a summary text for the email when only a single room has notifications. + + Args: + room_id: The ID of the room. + notifs: The notifications for this room. 
+ room_state_ids: The state map for the room. + notif_events: A map of event ID -> notification event. + + Returns: + The summary text. + """ + # If the room doesn't have a name, say who the messages + # are from explicitly to avoid, "messages in the Bob room" + sender_ids = {notif_events[n["event_id"]].sender for n in notifs} + + member_events = await self.store.get_events( + [room_state_ids[("m.room.member", s)] for s in sender_ids] + ) + + # There was a single sender. + if len(sender_ids) == 1: + return self.email_subjects.messages_from_person % { + "person": descriptor_from_member_events(member_events.values()), + "app": self.app_name, + } + + # There was more than one sender, use the first one and a tweaked template. + return self.email_subjects.messages_from_person_and_others % { + "person": descriptor_from_member_events(list(member_events.values())[:1]), + "app": self.app_name, + } + def make_room_link(self, room_id: str) -> str: if self.hs.config.email_riot_base_url: base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 961bf09de9..c4e1e7ed85 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -187,6 +187,36 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about those messages self._check_for_mail() + def test_multiple_rooms(self): + # We want to test multiple notifications from multiple rooms, so we pause + # processing of push while we send messages. + self.pusher._pause_processing() + + # Create a simple room with multiple other users + rooms = [ + self.helper.create_room_as(self.user_id, tok=self.access_token), + self.helper.create_room_as(self.user_id, tok=self.access_token), + ] + + for r, other in zip(rooms, self.others): + self.helper.invite( + room=r, src=self.user_id, tok=self.access_token, targ=other.id + ) + self.helper.join(room=r, user=other.id, tok=other.token) + + # The other users send some messages + self.helper.send(rooms[0], body="Hi!", tok=self.others[0].token) + self.helper.send(rooms[1], body="There!", tok=self.others[1].token) + self.helper.send(rooms[1], body="There!", tok=self.others[1].token) + + # Nothing should have happened yet, as we're paused. + assert not self.email_attempts + + self.pusher._resume_processing() + + # We should get emailed about those messages + self._check_for_mail() + def test_encrypted_message(self): room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.helper.invite( -- cgit 1.5.1 From 846b9d3df033be1043710e49e89bcba68722071e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 22:56:01 +0000 Subject: Put OIDC callback URI under /_synapse/client. 
(#9288) --- CHANGES.md | 4 +++ UPGRADE.rst | 13 ++++++++- changelog.d/9288.feature | 1 + docs/openid.md | 19 ++++++------- docs/workers.md | 2 +- synapse/config/oidc_config.py | 2 +- synapse/handlers/oidc_handler.py | 8 +++--- synapse/rest/oidc/__init__.py | 27 ------------------- synapse/rest/oidc/callback_resource.py | 30 --------------------- synapse/rest/synapse/client/__init__.py | 4 +-- synapse/rest/synapse/client/oidc/__init__.py | 31 ++++++++++++++++++++++ .../rest/synapse/client/oidc/callback_resource.py | 30 +++++++++++++++++++++ tests/handlers/test_oidc.py | 15 +++++------ 13 files changed, 102 insertions(+), 84 deletions(-) create mode 100644 changelog.d/9288.feature delete mode 100644 synapse/rest/oidc/__init__.py delete mode 100644 synapse/rest/oidc/callback_resource.py create mode 100644 synapse/rest/synapse/client/oidc/__init__.py create mode 100644 synapse/rest/synapse/client/oidc/callback_resource.py (limited to 'tests') diff --git a/CHANGES.md b/CHANGES.md index fcd782fa94..e9ff14a03d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,10 @@ Unreleased Note that this release includes a change in Synapse to use Redis as a cache ─ as well as a pub/sub mechanism ─ if Redis support is enabled. No action is needed by server administrators, and we do not expect resource usage of the Redis instance to change dramatically. +This release also changes the callback URI for OpenID Connect (OIDC) identity +providers. If your server is configured to use single sign-on via an +OIDC/OAuth2 IdP, you may need to make configuration changes. Please review +[UPGRADE.rst](UPGRADE.rst) for more details on these changes. Synapse 1.26.0 (2021-01-27) =========================== diff --git a/UPGRADE.rst b/UPGRADE.rst index eea0322695..d00f718cae 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -88,6 +88,17 @@ for example: Upgrading to v1.27.0 ==================== +Changes to callback URI for OAuth2 / OpenID Connect +--------------------------------------------------- + +This version changes the URI used for callbacks from OAuth2 identity providers. If +your server is configured for single sign-on via an OpenID Connect or OAuth2 identity +provider, you will need to add ``[synapse public baseurl]/_synapse/client/oidc/callback`` +to the list of permitted "redirect URIs" at the identity provider. + +See `docs/openid.md `_ for more information on setting up OpenID +Connect. + Changes to HTML templates ------------------------- @@ -235,7 +246,7 @@ shown below: return {"localpart": localpart} -Removal historical Synapse Admin API +Removal historical Synapse Admin API ------------------------------------ Historically, the Synapse Admin API has been accessible under: diff --git a/changelog.d/9288.feature b/changelog.d/9288.feature new file mode 100644 index 0000000000..efde69fb3c --- /dev/null +++ b/changelog.d/9288.feature @@ -0,0 +1 @@ +Update the redirect URI for OIDC authentication. diff --git a/docs/openid.md b/docs/openid.md index 3d07220967..9d19368845 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -54,7 +54,8 @@ Here are a few configs for providers that should work with Synapse. ### Microsoft Azure Active Directory Azure AD can act as an OpenID Connect Provider. Register a new application under *App registrations* in the Azure AD management console. 
The RedirectURI for your -application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback` +application should point to your matrix server: +`[synapse public baseurl]/_synapse/client/oidc/callback` Go to *Certificates & secrets* and register a new client secret. Make note of your Directory (tenant) ID as it will be used in the Azure links. @@ -94,7 +95,7 @@ staticClients: - id: synapse secret: secret redirectURIs: - - '[synapse public baseurl]/_synapse/oidc/callback' + - '[synapse public baseurl]/_synapse/client/oidc/callback' name: 'Synapse' ``` @@ -140,7 +141,7 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to | Enabled | `On` | | Client Protocol | `openid-connect` | | Access Type | `confidential` | -| Valid Redirect URIs | `[synapse public baseurl]/_synapse/oidc/callback` | +| Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` | 5. Click `Save` 6. On the Credentials tab, update the fields: @@ -168,7 +169,7 @@ oidc_providers: ### [Auth0][auth0] 1. Create a regular web application for Synapse -2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/oidc/callback` +2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/client/oidc/callback` 3. Add a rule to add the `preferred_username` claim.
    Code sample @@ -217,7 +218,7 @@ login mechanism needs an attribute to uniquely identify users, and that endpoint does not return a `sub` property, an alternative `subject_claim` has to be set. 1. Create a new OAuth application: https://github.com/settings/applications/new. -2. Set the callback URL to `[synapse public baseurl]/_synapse/oidc/callback`. +2. Set the callback URL to `[synapse public baseurl]/_synapse/client/oidc/callback`. Synapse config: @@ -262,13 +263,13 @@ oidc_providers: display_name_template: "{{ user.name }}" ``` 4. Back in the Google console, add this Authorized redirect URI: `[synapse - public baseurl]/_synapse/oidc/callback`. + public baseurl]/_synapse/client/oidc/callback`. ### Twitch 1. Setup a developer account on [Twitch](https://dev.twitch.tv/) 2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/) -3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/oidc/callback` +3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/client/oidc/callback` Synapse config: @@ -290,7 +291,7 @@ oidc_providers: 1. Create a [new application](https://gitlab.com/profile/applications). 2. Add the `read_user` and `openid` scopes. -3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback` +3. Add this Callback URL: `[synapse public baseurl]/_synapse/client/oidc/callback` Synapse config: @@ -323,7 +324,7 @@ one so requires a little more configuration. 2. Once the app is created, add "Facebook Login" and choose "Web". You don't need to go through the whole form here. 3. In the left-hand menu, open "Products"/"Facebook Login"/"Settings". - * Add `[synapse public baseurl]/_synapse/oidc/callback` as an OAuth Redirect + * Add `[synapse public baseurl]/_synapse/client/oidc/callback` as an OAuth Redirect URL. 4. In the left-hand menu, open "Settings/Basic". Here you can copy the "App ID" and "App Secret" for use below. diff --git a/docs/workers.md b/docs/workers.md index c36549c621..c4a6c79238 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -266,7 +266,7 @@ using): ^/_synapse/client/sso_register$ # OpenID Connect requests. - ^/_synapse/oidc/callback$ + ^/_synapse/client/oidc/callback$ # SAML requests. ^/_matrix/saml2/authn_response$ diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index bb122ef182..4c24c50629 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -53,7 +53,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback" + self.oidc_callback_url = self.public_baseurl + "_synapse/client/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index ca647fa78f..71008ec50d 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -102,7 +102,7 @@ class OidcHandler: ) from e async def handle_oidc_callback(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback Since we might want to display OIDC-related errors in a user-friendly way, we don't raise SynapseError from here. 
Instead, we call @@ -643,7 +643,7 @@ class OidcProvider: - ``client_id``: the client ID set in ``oidc_config.client_id`` - ``response_type``: ``code`` - - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback`` + - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback`` - ``scope``: the list of scopes set in ``oidc_config.scopes`` - ``state``: a random string - ``nonce``: a random string @@ -684,7 +684,7 @@ class OidcProvider: request.addCookie( SESSION_COOKIE_NAME, cookie, - path="/_synapse/oidc", + path="/_synapse/client/oidc", max_age="3600", httpOnly=True, sameSite="lax", @@ -705,7 +705,7 @@ class OidcProvider: async def handle_oidc_callback( self, request: SynapseRequest, session_data: "OidcSessionData", code: str ) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback By this time we have already validated the session on the synapse side, and now need to do the provider-specific operations. This includes: diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py deleted file mode 100644 index d958dd65bb..0000000000 --- a/synapse/rest/oidc/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Quentin Gliech -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from twisted.web.resource import Resource - -from synapse.rest.oidc.callback_resource import OIDCCallbackResource - -logger = logging.getLogger(__name__) - - -class OIDCResource(Resource): - def __init__(self, hs): - Resource.__init__(self) - self.putChild(b"callback", OIDCCallbackResource(hs)) diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py deleted file mode 100644 index f7a0bc4bdb..0000000000 --- a/synapse/rest/oidc/callback_resource.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Quentin Gliech -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging - -from synapse.http.server import DirectServeHtmlResource - -logger = logging.getLogger(__name__) - - -class OIDCCallbackResource(DirectServeHtmlResource): - isLeaf = 1 - - def __init__(self, hs): - super().__init__() - self._oidc_handler = hs.get_oidc_handler() - - async def _async_render_GET(self, request): - await self._oidc_handler.handle_oidc_callback(request) diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 02310c1900..381baf9729 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -47,9 +47,9 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc # provider-specific SSO bits. Only load these if they are enabled, since they # rely on optional dependencies. if hs.config.oidc_enabled: - from synapse.rest.oidc import OIDCResource + from synapse.rest.synapse.client.oidc import OIDCResource - resources["/_synapse/oidc"] = OIDCResource(hs) + resources["/_synapse/client/oidc"] = OIDCResource(hs) if hs.config.saml2_enabled: from synapse.rest.saml2 import SAML2Resource diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py new file mode 100644 index 0000000000..64c0deb75d --- /dev/null +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource + +logger = logging.getLogger(__name__) + + +class OIDCResource(Resource): + def __init__(self, hs): + Resource.__init__(self) + self.putChild(b"callback", OIDCCallbackResource(hs)) + + +__all__ = ["OIDCResource"] diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py new file mode 100644 index 0000000000..f7a0bc4bdb --- /dev/null +++ b/synapse/rest/synapse/client/oidc/callback_resource.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +from synapse.http.server import DirectServeHtmlResource + +logger = logging.getLogger(__name__) + + +class OIDCCallbackResource(DirectServeHtmlResource): + isLeaf = 1 + + def __init__(self, hs): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + async def _async_render_GET(self, request): + await self._oidc_handler.handle_oidc_callback(request) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index d8f90b9a80..ad20400b1d 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -40,7 +40,7 @@ ISSUER = "https://issuer/" CLIENT_ID = "test-client-id" CLIENT_SECRET = "test-client-secret" BASE_URL = "https://synapse/" -CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] AUTHORIZATION_ENDPOINT = ISSUER + "authorize" @@ -58,12 +58,6 @@ COMMON_CONFIG = { } -# The cookie name and path don't really matter, just that it has to be coherent -# between the callback & redirect handlers. -COOKIE_NAME = b"oidc_session" -COOKIE_PATH = "/_synapse/oidc" - - class TestMappingProvider: @staticmethod def parse_config(config): @@ -340,8 +334,11 @@ class OidcHandlerTestCase(HomeserverTestCase): # For some reason, call.args does not work with python3.5 args = calls[0][0] kwargs = calls[0][1] - self.assertEqual(args[0], COOKIE_NAME) - self.assertEqual(kwargs["path"], COOKIE_PATH) + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + self.assertEqual(args[0], b"oidc_session") + self.assertEqual(kwargs["path"], "/_synapse/client/oidc") cookie = args[1] macaroon = pymacaroons.Macaroon.deserialize(cookie) -- cgit 1.5.1 From b60bb28bbc3d916586a913970298baba483efc1f Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Tue, 2 Feb 2021 04:16:29 -0700 Subject: Add an admin API to get the current room state (#9168) This could arguably replace the existing admin API for `/members`, however that is out of scope of this change. This sort of endpoint is ideal for moderation use cases as well as other applications, such as needing to retrieve various bits of information about a room to perform a task (like syncing power levels between two places). This endpoint exposes nothing more than an admin would be able to access with a `select *` query on their database. --- changelog.d/9168.feature | 1 + docs/admin_api/rooms.md | 30 ++++++++++++++++++++++++++++++ synapse/handlers/message.py | 2 +- synapse/rest/admin/__init__.py | 2 ++ synapse/rest/admin/rooms.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/rest/admin/test_room.py | 15 +++++++++++++++ 6 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9168.feature (limited to 'tests') diff --git a/changelog.d/9168.feature b/changelog.d/9168.feature new file mode 100644 index 0000000000..8be1950eee --- /dev/null +++ b/changelog.d/9168.feature @@ -0,0 +1 @@ +Add an admin API for retrieving the current room state of a room. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index f34cec1ff7..3832b36407 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -368,6 +368,36 @@ Response: } ``` +# Room State API + +The Room State admin API allows server admins to get a list of all state events in a room. + +The response includes the following fields: + +* `state` - The current state of the room at the time of request. 
+ +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms//state + +{} +``` + +Response: + +```json +{ + "state": [ + {"type": "m.room.create", "state_key": "", "etc": true}, + {"type": "m.room.power_levels", "state_key": "", "etc": true}, + {"type": "m.room.name", "state_key": "", "etc": true} + ] +} +``` + # Delete Room API The Delete Room admin API allows server admins to remove rooms from server diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e2a7d567fa..a15336bf00 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -174,7 +174,7 @@ class MessageHandler: raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = await filter_events_for_client( - self.storage, user_id, last_events, filter_send_to_client=False + self.storage, user_id, last_events, filter_send_to_client=False, ) event = last_events[0] diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 57e0a10837..f5c5d164f9 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.admin.rooms import ( MakeRoomAdminRestServlet, RoomMembersRestServlet, RoomRestServlet, + RoomStateRestServlet, ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -213,6 +214,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f14915d47e..3e57e6a4d0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -292,6 +292,45 @@ class RoomMembersRestServlet(RestServlet): return 200, ret +class RoomStateRestServlet(RestServlet): + """ + Get full state within a room. + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)/state") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + ret = await self.store.get_room(room_id) + if not ret: + raise NotFoundError("Room not found") + + event_ids = await self.store.get_current_state_ids(room_id) + events = await self.store.get_events(event_ids.values()) + now = self.clock.time_msec() + room_state = await self._event_serializer.serialize_events( + events.values(), + now, + # We don't bother bundling aggregations in when asked for state + # events, as clients won't use them. 
+ bundle_aggregations=False, + ) + ret = {"state": room_state} + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index a0f32c5512..7c47aa7e0a 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["total"], 3) + def test_room_state(self): + """Test that room state can be requested correctly""" + # Create two test rooms + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("state", channel.json_body) + # testing that the state events match is painful and not done here. We assume that + # the create_room already does the right thing, so no need to verify that we got + # the state events it created. + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From ff55300b916b09fef83feb1f8ecaa07d90879c57 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 3 Feb 2021 10:17:37 +0000 Subject: Honour ratelimit flag for application services for invite ratelimiting (#9302) --- changelog.d/9302.bugfix | 1 + synapse/handlers/federation.py | 4 +++- synapse/handlers/room_member.py | 12 +++++++--- tests/handlers/test_federation.py | 47 --------------------------------------- 4 files changed, 13 insertions(+), 51 deletions(-) create mode 100644 changelog.d/9302.bugfix (limited to 'tests') diff --git a/changelog.d/9302.bugfix b/changelog.d/9302.bugfix new file mode 100644 index 0000000000..c1cdea52a3 --- /dev/null +++ b/changelog.d/9302.bugfix @@ -0,0 +1 @@ +Fix new ratelimiting for invites to respect the `ratelimit` flag on application services. Introduced in v1.27.0rc1. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index dbdfd56ff5..eddc7582d0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1619,7 +1619,9 @@ class FederationHandler(BaseHandler): # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() - member_handler.ratelimit_invite(event.room_id, event.state_key) + # We don't rate limit based on room ID, as that should be done by + # sending server. + member_handler.ratelimit_invite(None, event.state_key) # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d335da6f19..a5da97cfe0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -155,10 +155,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() - def ratelimit_invite(self, room_id: str, invitee_user_id: str): + def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): """Ratelimit invites by room and by target user. + + If room ID is missing then we just rate limit by target user. 
""" - self._invites_per_room_limiter.ratelimit(room_id) + if room_id: + self._invites_per_room_limiter.ratelimit(room_id) + self._invites_per_user_limiter.ratelimit(invitee_user_id) async def _local_membership_update( @@ -406,7 +410,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if effective_membership_state == Membership.INVITE: target_id = target.to_string() if ratelimit: - self.ratelimit_invite(room_id, target_id) + # Don't ratelimit application services. + if not requester.app_service or requester.app_service.is_rate_limited(): + self.ratelimit_invite(room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 74503112f5..983e368592 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -191,53 +191,6 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) - @unittest.override_config( - {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_invite_by_room_ratelimit(self): - """Tests that invites from federation in a room are actually rate-limited. - """ - other_server = "otherserver" - other_user = "@otheruser:" + other_server - - # create the room - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) - room_version = self.get_success(self.store.get_room_version(room_id)) - - def create_invite_for(local_user): - return event_from_pdu_json( - { - "type": EventTypes.Member, - "content": {"membership": "invite"}, - "room_id": room_id, - "sender": other_user, - "state_key": local_user, - "depth": 32, - "prev_events": [], - "auth_events": [], - "origin_server_ts": self.clock.time_msec(), - }, - room_version, - ) - - for i in range(3): - self.get_success( - self.handler.on_invite_request( - other_server, - create_invite_for("@user-%d:test" % (i,)), - room_version, - ) - ) - - self.get_failure( - self.handler.on_invite_request( - other_server, create_invite_for("@user-4:test"), room_version, - ), - exc=LimitExceededError, - ) - @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) -- cgit 1.5.1 From e40d88cff3cca3d5186d5f623ad1107bc403d69b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 11 Feb 2021 11:16:54 -0500 Subject: Backout changes for automatically calculating the public baseurl. (#9313) This breaks some people's configurations (if their Client-Server API is not accessed via port 443). 
--- changelog.d/9313.bugfix | 1 + docs/sample_config.yaml | 20 +++++++++----------- synapse/api/urls.py | 2 ++ synapse/config/cas.py | 16 +++++++++------- synapse/config/emailconfig.py | 8 ++++++++ synapse/config/oidc_config.py | 5 ++++- synapse/config/registration.py | 21 +++++++++++++++++---- synapse/config/saml2_config.py | 2 ++ synapse/config/server.py | 13 ++++--------- synapse/config/sso.py | 13 ++++++++----- synapse/handlers/identity.py | 4 ++++ synapse/rest/well_known.py | 4 ++++ synapse/util/templates.py | 15 ++++++++++++--- tests/rest/client/v1/test_login.py | 4 +++- tests/rest/test_well_known.py | 9 +++++++++ tests/utils.py | 1 + 16 files changed, 97 insertions(+), 41 deletions(-) create mode 100644 changelog.d/9313.bugfix (limited to 'tests') diff --git a/changelog.d/9313.bugfix b/changelog.d/9313.bugfix new file mode 100644 index 0000000000..f578fd13dd --- /dev/null +++ b/changelog.d/9313.bugfix @@ -0,0 +1 @@ +Do not automatically calculate `public_baseurl` since it can be wrong in some situations. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 236abd9a3f..d395da11b4 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -74,10 +74,6 @@ pid_file: DATADIR/homeserver.pid # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see # 'listeners' below). # -# If this is left unset, it defaults to 'https:///'. (Note that -# that will not work unless you configure Synapse or a reverse-proxy to listen -# on port 443.) -# #public_baseurl: https://example.com/ # Set the soft limit on the number of file descriptors synapse can use @@ -1169,9 +1165,8 @@ account_validity: # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' - # configuration section. You should also check that 'public_baseurl' is set - # correctly. + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. # #renew_at: 1w @@ -1262,7 +1257,8 @@ account_validity: # The identity server which we suggest that clients should use when users log # in on this server. # -# (By default, no suggestion is made, so it is left up to the client.) +# (By default, no suggestion is made, so it is left up to the client. +# This setting is ignored unless public_baseurl is also set.) # #default_identity_server: https://matrix.org @@ -1287,6 +1283,8 @@ account_validity: # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # +# If a delegate is specified, the config option public_baseurl must also be filled out. +# account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process @@ -1938,9 +1936,9 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # The login fallback page (used by clients that don't natively support the - # required login flows) is automatically whitelisted in addition to any URLs - # in this list. + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. # # By default, this list is empty. 
# diff --git a/synapse/api/urls.py b/synapse/api/urls.py index e36aeef31f..6379c86dde 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,6 +42,8 @@ class ConsentURIBuilder: """ if hs_config.form_secret is None: raise ConfigError("form_secret not set in config") + if hs_config.public_baseurl is None: + raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl diff --git a/synapse/config/cas.py b/synapse/config/cas.py index b226890c2a..aaa7eba110 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ConfigError class CasConfig(Config): @@ -30,13 +30,15 @@ class CasConfig(Config): if self.cas_enabled: self.cas_server_url = cas_config["server_url"] - public_base_url = cas_config.get("service_url") or self.public_baseurl - if public_base_url[-1] != "/": - public_base_url += "/" + + # The public baseurl is required because it is used by the redirect + # template. + public_baseurl = self.public_baseurl + if not public_baseurl: + raise ConfigError("cas_config requires a public_baseurl to be set") + # TODO Update this to a _synapse URL. - self.cas_service_url = ( - public_base_url + "_matrix/client/r0/login/cas/ticket" - ) + self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" self.cas_displayname_attribute = cas_config.get("displayname_attribute") self.cas_required_attributes = cas_config.get("required_attributes") or {} else: diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 6a487afd34..d4328c46b9 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -166,6 +166,11 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") + # public_baseurl is required to build password reset and validation links that + # will be emailed to users + if config.get("public_baseurl") is None: + missing.append("public_baseurl") + if missing: raise ConfigError( MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) @@ -264,6 +269,9 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") + if config.get("public_baseurl") is None: + missing.append("public_baseurl") + if missing: raise ConfigError( "email.enable_notifs is True but required keys are missing: %s" diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 4c24c50629..4d0f24a9d5 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -53,7 +53,10 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." 
% idp_id ) - self.oidc_callback_url = self.public_baseurl + "_synapse/client/oidc/callback" + public_baseurl = self.public_baseurl + if public_baseurl is None: + raise ConfigError("oidc_config requires a public_baseurl to be set") + self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ac48913a0b..eb650af7fb 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -49,6 +49,10 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 + if self.renew_by_email_enabled: + if "public_baseurl" not in synapse_config: + raise ConfigError("Can't send renewal emails without 'public_baseurl'") + template_dir = config.get("template_dir") if not template_dir: @@ -105,6 +109,13 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") + if self.account_threepid_delegate_msisdn and not self.public_baseurl: + raise ConfigError( + "The configuration option `public_baseurl` is required if " + "`account_threepid_delegate.msisdn` is set, such that " + "clients know where to submit validation tokens to. Please " + "configure `public_baseurl`." + ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -227,9 +238,8 @@ class RegistrationConfig(Config): # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' - # configuration section. You should also check that 'public_baseurl' is set - # correctly. + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. # #renew_at: 1w @@ -320,7 +330,8 @@ class RegistrationConfig(Config): # The identity server which we suggest that clients should use when users log # in on this server. # - # (By default, no suggestion is made, so it is left up to the client.) + # (By default, no suggestion is made, so it is left up to the client. + # This setting is ignored unless public_baseurl is also set.) # #default_identity_server: https://matrix.org @@ -345,6 +356,8 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # + # If a delegate is specified, the config option public_baseurl must also be filled out. 
+ # account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index ad865a667f..7226abd829 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -189,6 +189,8 @@ class SAML2Config(Config): import saml2 public_baseurl = self.public_baseurl + if public_baseurl is None: + raise ConfigError("saml2_config requires a public_baseurl to be set") if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) diff --git a/synapse/config/server.py b/synapse/config/server.py index 47a0370173..5d72cf2d82 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -161,11 +161,7 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) - self.public_baseurl = config.get("public_baseurl") or "https://%s/" % ( - self.server_name, - ) - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" + self.public_baseurl = config.get("public_baseurl") # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -321,6 +317,9 @@ class ServerConfig(Config): # Always blacklist 0.0.0.0, :: self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + if self.public_baseurl is not None: + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -748,10 +747,6 @@ class ServerConfig(Config): # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see # 'listeners' below). # - # If this is left unset, it defaults to 'https:///'. (Note that - # that will not work unless you configure Synapse or a reverse-proxy to listen - # on port 443.) - # #public_baseurl: https://example.com/ # Set the soft limit on the number of file descriptors synapse can use diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 6c60c6fea4..19bdfd462b 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -64,8 +64,11 @@ class SSOConfig(Config): # gracefully to the client). This would make it pointless to ask the user for # confirmation, since the URL the confirmation page would be showing wouldn't be # the client's. - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" - self.sso_client_whitelist.append(login_fallback_url) + # public_baseurl is an optional setting, so we only add the fallback's URL to the + # list if it's provided (because we can't figure out what that URL is otherwise). + if self.public_baseurl: + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): return """\ @@ -83,9 +86,9 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # The login fallback page (used by clients that don't natively support the - # required login flows) is automatically whitelisted in addition to any URLs - # in this list. 
+ # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. # # By default, this list is empty. # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 4f7137539b..8fc1e8b91c 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -504,6 +504,10 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") + # It is already checked that public_baseurl is configured since this code + # should only be used if account_threepid_delegate_msisdn is true. + assert self.hs.config.public_baseurl + # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 241fe746d9..f591cc6c5c 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,6 +34,10 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self): + # if we don't have a public_baseurl, we can't help much here. + if self._config.public_baseurl is None: + return None + result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: diff --git a/synapse/util/templates.py b/synapse/util/templates.py index 7e5109d206..392dae4a40 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -17,7 +17,7 @@ import time import urllib.parse -from typing import TYPE_CHECKING, Callable, Iterable, Union +from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union import jinja2 @@ -74,14 +74,23 @@ def build_jinja_env( return env -def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: +def _create_mxc_to_http_filter( + public_baseurl: Optional[str], +) -> Callable[[str, int, int, str], str]: """Create and return a jinja2 filter that converts MXC urls to HTTP Args: public_baseurl: The public, accessible base URL of the homeserver """ - def mxc_to_http_filter(value, width, height, resize_method="crop"): + def mxc_to_http_filter( + value: str, width: int, height: int, resize_method: str = "crop" + ) -> str: + if not public_baseurl: + raise RuntimeError( + "public_baseurl must be set in the homeserver config to convert MXC URLs to HTTP URLs." 
+ ) + if value[0:6] != "mxc://": return "" diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 66dfdaffbc..bfcb786af8 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -672,10 +672,12 @@ class CASTestCase(unittest.HomeserverTestCase): self.redirect_path = "_synapse/client/login/sso/redirect/confirm" config = self.default_config() + config["public_baseurl"] = ( + config.get("public_baseurl") or "https://matrix.goodserver.com:8448" + ) config["cas_config"] = { "enabled": True, "server_url": CAS_SERVER, - "service_url": "https://matrix.goodserver.com:8448", } cas_user_id = "username" diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index c5e44af9f7..14de0921be 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,3 +40,12 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) + + def test_well_known_no_public_baseurl(self): + self.hs.config.public_baseurl = None + + channel = self.make_request( + "GET", "/.well-known/matrix/client", shorthand=False + ) + + self.assertEqual(channel.code, 404) diff --git a/tests/utils.py b/tests/utils.py index 68033d7535..840b657f82 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -159,6 +159,7 @@ def default_config(name, parse=False): }, "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000}, "saml2_enabled": False, + "public_baseurl": None, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, "old_signing_keys": {}, -- cgit 1.5.1
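
Circling back to the email-summary refactor earlier in this series (#9260): the key behavioural change is that the choice between the `messages_from_person` and `messages_from_person_and_others` subject templates now keys off the number of distinct senders rather than the number of rooms. A stripped-down sketch of that selection logic, using invented subject templates and raw sender IDs in place of the real `EmailSubjectConfig` and member-event display names:

```python
from typing import Dict, List

# Invented stand-ins for the real email subject templates.
SUBJECTS = {
    "messages_from_person": "You have messages from %(person)s",
    "messages_from_person_and_others": "You have messages from %(person)s and others",
}


def summary_subject(notifs: List[Dict[str, str]]) -> str:
    """Pick a subject based on how many distinct senders there are,
    not how many rooms - the behaviour this commit introduces."""
    senders = {n["sender"] for n in notifs}
    first = sorted(senders)[0]
    if len(senders) == 1:
        return SUBJECTS["messages_from_person"] % {"person": first}
    # Multiple senders: name only the first one and use the tweaked template.
    return SUBJECTS["messages_from_person_and_others"] % {"person": first}


print(summary_subject([{"sender": "@bob:example.com"}]))
print(
    summary_subject(
        [{"sender": "@bob:example.com"}, {"sender": "@alice:example.com"}]
    )
)
```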