diff --git a/changelog.d/8802.doc b/changelog.d/8802.doc
new file mode 100644
index 0000000000..580c4281f8
--- /dev/null
+++ b/changelog.d/8802.doc
@@ -0,0 +1 @@
+Fix the "Event persist rate" section of the included grafana dashboard by adding missing prometheus rules.
diff --git a/changelog.d/8821.bugfix b/changelog.d/8821.bugfix
new file mode 100644
index 0000000000..8ddfbf31ce
--- /dev/null
+++ b/changelog.d/8821.bugfix
@@ -0,0 +1 @@
+Apply the `federation_ip_range_blacklist` to push and key revocation requests.
diff --git a/changelog.d/8827.bugfix b/changelog.d/8827.bugfix
new file mode 100644
index 0000000000..18195680d3
--- /dev/null
+++ b/changelog.d/8827.bugfix
@@ -0,0 +1 @@
+Fix bug where we might not correctly calculate the current state for rooms with multiple extremities.
diff --git a/changelog.d/8837.bugfix b/changelog.d/8837.bugfix
new file mode 100644
index 0000000000..b2977d0c31
--- /dev/null
+++ b/changelog.d/8837.bugfix
@@ -0,0 +1 @@
+Fix a long standing bug in the register admin endpoint (`/_synapse/admin/v1/register`) when the `mac` field was not provided. The endpoint now properly returns a 400 error. Contributed by @edwargix.
diff --git a/changelog.d/8858.bugfix b/changelog.d/8858.bugfix
new file mode 100644
index 0000000000..0d58cb9abc
--- /dev/null
+++ b/changelog.d/8858.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password.
diff --git a/changelog.d/8861.misc b/changelog.d/8861.misc
new file mode 100644
index 0000000000..9821f804cf
--- /dev/null
+++ b/changelog.d/8861.misc
@@ -0,0 +1 @@
+Remove some unnecessary stubbing from unit tests.
diff --git a/changelog.d/8864.misc b/changelog.d/8864.misc
new file mode 100644
index 0000000000..a780883495
--- /dev/null
+++ b/changelog.d/8864.misc
@@ -0,0 +1 @@
+Remove unused `FakeResponse` class from unit tests.
diff --git a/changelog.d/8867.bugfix b/changelog.d/8867.bugfix
new file mode 100644
index 0000000000..f2414ff111
--- /dev/null
+++ b/changelog.d/8867.bugfix
@@ -0,0 +1 @@
+Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled.
diff --git a/changelog.d/8873.doc b/changelog.d/8873.doc
new file mode 100644
index 0000000000..0c2a043bd1
--- /dev/null
+++ b/changelog.d/8873.doc
@@ -0,0 +1 @@
+Fix an error in the documentation for the SAML username mapping provider.
diff --git a/contrib/prometheus/synapse-v2.rules b/contrib/prometheus/synapse-v2.rules
index 6ccca2daaf..7e405bf7f0 100644
--- a/contrib/prometheus/synapse-v2.rules
+++ b/contrib/prometheus/synapse-v2.rules
@@ -58,3 +58,21 @@ groups:
labels:
type: "PDU"
expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
+
+ - record: synapse_storage_events_persisted_by_source_type
+ expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_type="remote"})
+ labels:
+ type: remote
+ - record: synapse_storage_events_persisted_by_source_type
+ expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity="*client*",origin_type="local"})
+ labels:
+ type: local
+ - record: synapse_storage_events_persisted_by_source_type
+ expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity!="*client*",origin_type="local"})
+ labels:
+ type: bridges
+ - record: synapse_storage_events_persisted_by_event_type
+ expr: sum without(origin_entity, origin_type) (synapse_storage_events_persisted_events_sep)
+ - record: synapse_storage_events_persisted_by_origin
+ expr: sum without(type) (synapse_storage_events_persisted_events_sep)
+
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 394eb9a3ff..6dbccf5932 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -642,17 +642,19 @@ acme:
# - nyc.example.com
# - syd.example.com
-# Prevent federation requests from being sent to the following
-# blacklist IP address CIDR ranges. If this option is not specified, or
-# specified with an empty list, no ip range blacklist will be enforced.
+# Prevent outgoing requests from being sent to the following blacklisted IP address
+# CIDR ranges. If this option is not specified, or specified with an empty list,
+# no IP range blacklist will be enforced.
#
-# As of Synapse v1.4.0 this option also affects any outbound requests to identity
-# servers provided by user input.
+# The blacklist applies to the outbound requests for federation, identity servers,
+# push servers, and for checking key validitity for third-party invite events.
#
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
-federation_ip_range_blacklist:
+# This option replaces federation_ip_range_blacklist in Synapse v1.24.0.
+#
+ip_range_blacklist:
- '127.0.0.0/8'
- '10.0.0.0/8'
- '172.16.0.0/12'
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index ab2a648910..7714b1d844 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -116,11 +116,13 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods:
-* `__init__(self, parsed_config)`
+* `__init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
+ - `module_api` - a `synapse.module_api.ModuleApi` object which provides the
+ stable API available for extension modules.
* `parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1b511890aa..aa12c74358 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -266,7 +266,6 @@ class GenericWorkerPresence(BasePresenceHandler):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
- self.http_client = hs.get_simple_http_client()
self._presence_enabled = hs.config.use_presence
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index ffd8fca54e..27ccf61c3c 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -36,22 +36,30 @@ class FederationConfig(Config):
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
- self.federation_ip_range_blacklist = config.get(
- "federation_ip_range_blacklist", []
- )
+ ip_range_blacklist = config.get("ip_range_blacklist", [])
# Attempt to create an IPSet from the given ranges
try:
- self.federation_ip_range_blacklist = IPSet(
- self.federation_ip_range_blacklist
- )
-
- # Always blacklist 0.0.0.0, ::
- self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+ self.ip_range_blacklist = IPSet(ip_range_blacklist)
+ except Exception as e:
+ raise ConfigError("Invalid range(s) provided in ip_range_blacklist: %s" % e)
+ # Always blacklist 0.0.0.0, ::
+ self.ip_range_blacklist.update(["0.0.0.0", "::"])
+
+ # The federation_ip_range_blacklist is used for backwards-compatibility
+ # and only applies to federation and identity servers. If it is not given,
+ # default to ip_range_blacklist.
+ federation_ip_range_blacklist = config.get(
+ "federation_ip_range_blacklist", ip_range_blacklist
+ )
+ try:
+ self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
except Exception as e:
raise ConfigError(
"Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
)
+ # Always blacklist 0.0.0.0, ::
+ self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
federation_metrics_domains = config.get("federation_metrics_domains") or []
validate_config(
@@ -76,17 +84,19 @@ class FederationConfig(Config):
# - nyc.example.com
# - syd.example.com
- # Prevent federation requests from being sent to the following
- # blacklist IP address CIDR ranges. If this option is not specified, or
- # specified with an empty list, no ip range blacklist will be enforced.
+ # Prevent outgoing requests from being sent to the following blacklisted IP address
+ # CIDR ranges. If this option is not specified, or specified with an empty list,
+ # no IP range blacklist will be enforced.
#
- # As of Synapse v1.4.0 this option also affects any outbound requests to identity
- # servers provided by user input.
+ # The blacklist applies to the outbound requests for federation, identity servers,
+ # push servers, and for checking key validitity for third-party invite events.
#
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
- federation_ip_range_blacklist:
+ # This option replaces federation_ip_range_blacklist in Synapse v1.24.0.
+ #
+ ip_range_blacklist:
- '127.0.0.0/8'
- '10.0.0.0/8'
- '172.16.0.0/12'
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c04ad77cf9..f23eacc0d7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -578,7 +578,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
def __init__(self, hs):
super().__init__(hs)
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
async def get_keys(self, keys_to_fetch):
@@ -748,7 +748,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
def __init__(self, hs):
super().__init__(hs)
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
async def get_keys(self, keys_to_fetch):
"""
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4b6ab470d0..35e345ce70 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -845,7 +845,6 @@ class FederationHandlerRegistry:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
self.clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 17a10f622e..abe9168c78 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -35,7 +35,7 @@ class TransportLayerClient:
def __init__(self, hs):
self.server_name = hs.hostname
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
@log_function
def get_room_state_ids(self, destination, room_id, event_id):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b53e7a20ec..434718ddfc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args:
hs (synapse.server.HomeServer): homeserver
- resource (TransportLayerServer): resource class to register to
+ resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..2e72298e05 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = (
- hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
- )
+ self._password_localdb_enabled = hs.config.password_localdb_enabled
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
@@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
- if hs.config.password_localdb_enabled:
+ if self._password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
for provider in self.password_providers:
@@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):
self._supported_login_types = login_types
- # Login types and UI Auth types have a heavy overlap, but are not
- # necessarily identical. Login types have SSO (and other login types)
- # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
- ui_auth_types = login_types.copy()
- if self._sso_enabled:
- ui_auth_types.append(LoginType.SSO)
- self._supported_ui_auth_types = ui_auth_types
-
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
- flows = [[login_type] for login_type in self._supported_ui_auth_types]
+ supported_ui_auth_types = await self._get_available_ui_auth_types(
+ requester.user
+ )
+ flows = [[login_type] for login_type in supported_ui_auth_types]
try:
result, params, session_id = await self.check_ui_auth(
@@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
raise
# find the completed login type
- for login_type in self._supported_ui_auth_types:
+ for login_type in supported_ui_auth_types:
if login_type not in result:
continue
@@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):
return params, session_id
+ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+ """Get a list of the authentication types this user can use
+ """
+
+ ui_auth_types = set()
+
+ # if the HS supports password auth, and the user has a non-null password, we
+ # support password auth
+ if self._password_localdb_enabled and self._password_enabled:
+ lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+ if lookupres:
+ _, password_hash = lookupres
+ if password_hash:
+ ui_auth_types.add(LoginType.PASSWORD)
+
+ # also allow auth from password providers
+ for provider in self.password_providers:
+ for t in provider.get_supported_login_types().keys():
+ if t == LoginType.PASSWORD and not self._password_enabled:
+ continue
+ ui_auth_types.add(t)
+
+ # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+ # from sso to mxid.
+ if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+ if await self.store.get_external_ids_by_user(user.to_string()):
+ ui_auth_types.add(LoginType.SSO)
+
+ # Our CAS impl does not (yet) correctly register users in user_external_ids,
+ # so always offer that if it's available.
+ if self.hs.config.cas.cas_enabled:
+ ui_auth_types.add(LoginType.SSO)
+
+ return ui_auth_types
+
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
if result:
return result
- if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+ if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090f7..df82e60b33 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_blacklisted_http_client()
self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b4551..7301c24710 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -46,13 +46,13 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs):
super().__init__(hs)
+ # An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
- # We create a blacklisting instance of SimpleHttpClient for contacting identity
- # servers specified by clients
+ # An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
- self.federation_http_client = hs.get_http_client()
+ self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
async def threepid_from_creds(
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e5b13593f2..df7730078f 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -125,7 +125,7 @@ def _make_scheduler(reactor):
return _scheduler
-class IPBlacklistingResolver:
+class _IPBlacklistingResolver:
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
addresses, preventing DNS rebinding attacks on URL preview.
@@ -199,6 +199,35 @@ class IPBlacklistingResolver:
return r
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+ """
+ A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+ addresses, to prevent DNS rebinding.
+ """
+
+ def __init__(
+ self,
+ reactor: IReactorPluggableNameResolver,
+ ip_whitelist: Optional[IPSet],
+ ip_blacklist: IPSet,
+ ):
+ self._reactor = reactor
+
+ # We need to use a DNS resolver which filters out blacklisted IP
+ # addresses, to prevent DNS rebinding.
+ self._nameResolver = _IPBlacklistingResolver(
+ self._reactor, ip_whitelist, ip_blacklist
+ )
+
+ def __getattr__(self, attr: str) -> Any:
+ # Passthrough to the real reactor except for the DNS resolver.
+ if attr == "nameResolver":
+ return self._nameResolver
+ else:
+ return getattr(self._reactor, attr)
+
+
class BlacklistingAgentWrapper(Agent):
"""
An Agent wrapper which will prevent access to IP addresses being accessed
@@ -292,22 +321,11 @@ class SimpleHttpClient:
self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
- real_reactor = hs.get_reactor()
# If we have an IP blacklist, we need to use a DNS resolver which
# filters out blacklisted IP addresses, to prevent DNS rebinding.
- nameResolver = IPBlacklistingResolver(
- real_reactor, self._ip_whitelist, self._ip_blacklist
+ self.reactor = BlacklistingReactorWrapper(
+ hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
)
-
- @implementer(IReactorPluggableNameResolver)
- class Reactor:
- def __getattr__(_self, attr):
- if attr == "nameResolver":
- return nameResolver
- else:
- return getattr(real_reactor, attr)
-
- self.reactor = Reactor()
else:
self.reactor = hs.get_reactor()
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index e77f9587d0..3b756a7dc2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -16,7 +16,7 @@ import logging
import urllib.parse
from typing import List, Optional
-from netaddr import AddrFormatError, IPAddress
+from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
from twisted.internet import defer
@@ -31,6 +31,7 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
from synapse.crypto.context_factory import FederationPolicyForHTTPS
+from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -70,6 +71,7 @@ class MatrixFederationAgent:
reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
+ ip_blacklist: IPSet,
_srv_resolver: Optional[SrvResolver] = None,
_well_known_resolver: Optional[WellKnownResolver] = None,
):
@@ -90,12 +92,18 @@ class MatrixFederationAgent:
self.user_agent = user_agent
if _well_known_resolver is None:
+ # Note that the name resolver has already been wrapped in a
+ # IPBlacklistingResolver by MatrixFederationHttpClient.
_well_known_resolver = WellKnownResolver(
self._reactor,
- agent=Agent(
+ agent=BlacklistingAgentWrapper(
+ Agent(
+ self._reactor,
+ pool=self._pool,
+ contextFactory=tls_client_options_factory,
+ ),
self._reactor,
- pool=self._pool,
- contextFactory=tls_client_options_factory,
+ ip_blacklist=ip_blacklist,
),
user_agent=self.user_agent,
)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 4e27f93b7a..c962994727 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -26,11 +26,10 @@ import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
-from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
+from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
@@ -45,7 +44,7 @@ from synapse.api.errors import (
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
- IPBlacklistingResolver,
+ BlacklistingReactorWrapper,
encode_query_args,
readBodyToFile,
)
@@ -221,31 +220,22 @@ class MatrixFederationHttpClient:
self.signing_key = hs.signing_key
self.server_name = hs.hostname
- real_reactor = hs.get_reactor()
-
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
- nameResolver = IPBlacklistingResolver(
- real_reactor, None, hs.config.federation_ip_range_blacklist
+ self.reactor = BlacklistingReactorWrapper(
+ hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
)
- @implementer(IReactorPluggableNameResolver)
- class Reactor:
- def __getattr__(_self, attr):
- if attr == "nameResolver":
- return nameResolver
- else:
- return getattr(real_reactor, attr)
-
- self.reactor = Reactor()
-
user_agent = hs.version_string
if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii")
self.agent = MatrixFederationAgent(
- self.reactor, tls_client_options_factory, user_agent
+ self.reactor,
+ tls_client_options_factory,
+ user_agent,
+ hs.config.federation_ip_range_blacklist,
)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index eff0975b6a..0e845212a9 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -100,7 +100,7 @@ class HttpPusher:
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
self.url = self.data["url"]
- self.http_client = hs.get_proxied_http_client()
+ self.http_client = hs.get_proxied_blacklisted_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1ead..90940ff185 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -420,6 +420,9 @@ class UserRegisterServlet(RestServlet):
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
+ if "mac" not in body:
+ raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
got_mac = body["mac"]
want_mac_builder = hmac.new(
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a89ae6ddf9..9041e7ed76 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -451,7 +451,7 @@ class RegisterRestServlet(RestServlet):
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
- raise SynapseError(403, "Registration has been disabled")
+ raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
# For regular registration, convert the provided username to lowercase
# before attempting to register it. This should mean that people who try
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9cac74ebd8..83beb02b05 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -66,7 +66,7 @@ class MediaRepository:
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
diff --git a/synapse/server.py b/synapse/server.py
index b017e3489f..9af759626e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -350,17 +350,46 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient:
+ """
+ An HTTP client with no special configuration.
+ """
return SimpleHttpClient(self)
@cache_in_self
def get_proxied_http_client(self) -> SimpleHttpClient:
+ """
+ An HTTP client that uses configured HTTP(S) proxies.
+ """
+ return SimpleHttpClient(
+ self,
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
+ )
+
+ @cache_in_self
+ def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
+ """
+ An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
+ based on the IP range blacklist.
+ """
return SimpleHttpClient(
self,
+ ip_blacklist=self.config.ip_range_blacklist,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
@cache_in_self
+ def get_federation_http_client(self) -> MatrixFederationHttpClient:
+ """
+ An HTTP client for federation.
+ """
+ tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
+ self.config
+ )
+ return MatrixFederationHttpClient(self, tls_client_options_factory)
+
+ @cache_in_self
def get_room_creation_handler(self) -> RoomCreationHandler:
return RoomCreationHandler(self)
@@ -515,13 +544,6 @@ class HomeServer(metaclass=abc.ABCMeta):
return PusherPool(self)
@cache_in_self
- def get_http_client(self) -> MatrixFederationHttpClient:
- tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
- self.config
- )
- return MatrixFederationHttpClient(self, tls_client_options_factory)
-
- @cache_in_self
def get_media_repository_resource(self) -> MediaRepositoryResource:
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d728..ffc504ce77 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -252,9 +252,88 @@ async def _get_auth_chain_difference(
Set of event IDs
"""
- difference = await state_res_store.get_auth_chain_difference(
- [set(state_set.values()) for state_set in state_sets]
- )
+ # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+ # all events passed to it (and their auth chains) have been persisted
+ # previously. This is not the case for any events in the `event_map`, and so
+ # we need to manually handle those events.
+ #
+ # We do this by:
+ # 1. calculating the auth chain difference for the state sets based on the
+ # events in `event_map` alone
+ # 2. replacing any events in the state_sets that are also in `event_map`
+ # with their auth events (recursively), and then calling
+ # `store.get_auth_chain_difference` as normal
+ # 3. adding the results of 1 and 2 together.
+
+ # Map from event ID in `event_map` to their auth event IDs, and their auth
+ # event IDs if they appear in the `event_map`. This is the intersection of
+ # the event's auth chain with the events in the `event_map` *plus* their
+ # auth event IDs.
+ events_to_auth_chain = {} # type: Dict[str, Set[str]]
+ for event in event_map.values():
+ chain = {event.event_id}
+ events_to_auth_chain[event.event_id] = chain
+
+ to_search = [event]
+ while to_search:
+ for auth_id in to_search.pop().auth_event_ids():
+ chain.add(auth_id)
+ auth_event = event_map.get(auth_id)
+ if auth_event:
+ to_search.append(auth_event)
+
+ # We now a) calculate the auth chain difference for the unpersisted events
+ # and b) work out the state sets to pass to the store.
+ #
+ # Note: If the `event_map` is empty (which is the common case), we can do a
+ # much simpler calculation.
+ if event_map:
+ # The list of state sets to pass to the store, where each state set is a set
+ # of the event ids making up the state. This is similar to `state_sets`,
+ # except that (a) we only have event ids, not the complete
+ # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+ # unpersisted events and replaced them with the persisted events in
+ # their auth chain.
+ state_sets_ids = [] # type: List[Set[str]]
+
+ # For each state set, the unpersisted event IDs reachable (by their auth
+ # chain) from the events in that set.
+ unpersisted_set_ids = [] # type: List[Set[str]]
+
+ for state_set in state_sets:
+ set_ids = set() # type: Set[str]
+ state_sets_ids.append(set_ids)
+
+ unpersisted_ids = set() # type: Set[str]
+ unpersisted_set_ids.append(unpersisted_ids)
+
+ for event_id in state_set.values():
+ event_chain = events_to_auth_chain.get(event_id)
+ if event_chain is not None:
+ # We have an event in `event_map`. We add all the auth
+ # events that it references (that aren't also in `event_map`).
+ set_ids.update(e for e in event_chain if e not in event_map)
+
+ # We also add the full chain of unpersisted event IDs
+ # referenced by this state set, so that we can work out the
+ # auth chain difference of the unpersisted events.
+ unpersisted_ids.update(e for e in event_chain if e in event_map)
+ else:
+ set_ids.add(event_id)
+
+ # The auth chain difference of the unpersisted events of the state sets
+ # is calculated by taking the difference between the union and
+ # intersections.
+ union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+ intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+ difference_from_event_map = union - intersection # type: Collection[str]
+ else:
+ difference_from_event_map = ()
+ state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
+ difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
+ difference.update(difference_from_event_map)
return difference
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fedb8a6c26..ff96c34c2e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_external_id",
)
+ async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
+ """Look up external ids for the given user
+
+ Args:
+ mxid: the MXID to be looked up
+
+ Returns:
+ Tuples of (auth_provider, external_id)
+ """
+ res = await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ )
+ return [(r["auth_provider"], r["external_id"]) for r in res]
+
async def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
+ self.db_pool.updates.register_background_index_update(
+ "user_external_ids_user_id_idx",
+ index_name="user_external_ids_user_id_idx",
+ table="user_external_ids",
+ columns=["user_id"],
+ unique=False,
+ )
+
async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
new file mode 100644
index 0000000000..8f5e65aa71
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5825, 'user_external_ids_user_id_idx', '{}');
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index c98ae75974..279c94a03d 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -16,8 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
-
import jsonschema
from twisted.internet import defer
@@ -28,7 +26,7 @@ from synapse.api.filtering import Filter
from synapse.events import make_event_from_dict
from tests import unittest
-from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
+from tests.utils import setup_test_homeserver
user_localpart = "test_user"
@@ -42,19 +40,9 @@ def MockEvent(**kwargs):
class FilteringTestCase(unittest.TestCase):
- @defer.inlineCallbacks
def setUp(self):
- self.mock_federation_resource = MockHttpResource()
-
- self.mock_http_client = Mock(spec=[])
- self.mock_http_client.put_json = DeferredMockCallable()
-
- hs = yield setup_test_homeserver(
- self.addCleanup, http_client=self.mock_http_client, keyring=Mock(),
- )
-
+ hs = setup_test_homeserver(self.addCleanup)
self.filtering = hs.get_filtering()
-
self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 40abe9d72d..43fef5d64a 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -23,7 +23,7 @@ class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserver_to_use=GenericWorkerServer
+ federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index ea3be95cf1..b260ab734d 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserver_to_use=GenericWorkerServer
+ federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
@@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserver_to_use=SynapseHomeServer
+ federation_http_client=None, homeserver_to_use=SynapseHomeServer
)
return hs
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 697916a019..d146f2254f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
- hs = self.setup_test_homeserver(http_client=self.http_client)
+ hs = self.setup_test_homeserver(federation_http_client=self.http_client)
return hs
def test_get_keys_from_server(self):
@@ -395,7 +395,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
}
]
- return self.setup_test_homeserver(http_client=self.http_client, config=config)
+ return self.setup_test_homeserver(
+ federation_http_client=self.http_client, config=config
+ )
def build_perspectives_response(
self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 875aaec2c6..5dfeccfeb6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -27,7 +27,7 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
return hs
@@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ee6ef5e6fa..770d225ed5 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,8 +42,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.mock_registry.register_query_handler = register_query_handler
hs = self.setup_test_homeserver(
- http_client=None,
- resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_registry=self.mock_registry,
)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bf866dacf3..d0452e1490 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -37,7 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(http_client=None)
+ hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastore()
return hs
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..1d99a45436 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse
from mock import Mock, patch
-import attr
import pymacaroons
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
+from tests.test_utils import FakeResponse
from tests.unittest import HomeserverTestCase, override_config
-
-@attr.s
-class FakeResponse:
- code = attr.ib()
- body = attr.ib()
- phrase = attr.ib()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
-
-
# These are a few constants that are used as config parameters in the tests.
ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 8ed67640f8..0794b32c9c 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -463,7 +463,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "server", http_client=None, federation_sender=Mock()
+ "server", federation_http_client=None, federation_sender=Mock()
)
return hs
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..919547556b 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
- http_client=None,
- resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index abbdf2d524..f21de958f1 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
import json
+from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.web.resource import Resource
from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
-from tests.utils import register_federation_servlets
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- servlets = [register_federation_servlets]
-
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@@ -70,13 +70,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(
notifier=Mock(),
- http_client=mock_federation_client,
+ federation_http_client=mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
return hs
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
+
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -192,7 +197,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- put_json = self.hs.get_http_client().put_json
+ put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
@@ -270,7 +275,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
- put_json = self.hs.get_http_client().put_json
+ put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..626acdcaa3 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,6 +17,7 @@ import logging
from mock import Mock
import treq
+from netaddr import IPSet
from service_identity import VerificationError
from zope.interface import implementer
@@ -103,6 +104,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor,
tls_client_options_factory=self.tls_factory,
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
+ ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
@@ -736,6 +738,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor,
tls_client_options_factory=tls_factory,
user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
+ ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
self.reactor,
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index f118430309..e8cea39c83 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -49,7 +49,9 @@ class HTTPPusherTests(HomeserverTestCase):
config = self.default_config()
config["start_pushers"] = True
- hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
+ hs = self.setup_test_homeserver(
+ config=config, proxied_blacklisted_http_client=m
+ )
return hs
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..3379189785 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- servlets = [
- streams.register_servlets,
- ]
-
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -67,7 +64,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
- http_client=None,
+ federation_http_client=None,
homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+ return d
+
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
@@ -264,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
- useful to e.g. pass some mocks for things like `http_client`
+ useful to e.g. pass some mocks for things like `federation_http_client`
Returns:
The new worker HomeServer instance.
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 779745ae9d..fffdb742c8 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.federation_sender",
{"send_federation": True},
- http_client=mock_client,
+ federation_http_client=mock_client,
)
user = self.register_user("user", "pass")
@@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user2", "pass")
@@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user3", "pass")
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 67c27a089f..f894bcd6e7 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.pusher",
{"start_pushers": True},
- proxied_http_client=http_client_mock,
+ proxied_blacklisted_http_client=http_client_mock,
)
event_id = self._create_pusher_and_send_msg("user")
@@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock1,
+ proxied_blacklisted_http_client=http_client_mock1,
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock2,
+ proxied_blacklisted_http_client=http_client_mock2,
)
# We choose a user name that we know should go to pusher1.
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 4f76f8f768..67d8878395 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -210,7 +210,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
}
config["media_storage_providers"] = [provider_config]
- hs = self.setup_test_homeserver(config=config, http_client=client)
+ hs = self.setup_test_homeserver(config=config, federation_http_client=client)
return hs
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 5d5c24d01c..11cd8efe21 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -38,7 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(
"red",
- http_client=None,
+ federation_http_client=None,
federation_client=Mock(),
presence_handler=presence_handler,
)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 383a9eafac..2a3b483eaf 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -63,7 +63,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
"test",
- http_client=None,
+ federation_http_client=None,
resource_for_client=self.mock_resource,
federation=Mock(),
federation_client=Mock(),
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 49f1073c88..e67de41c18 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -45,7 +45,7 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock(),
+ "red", federation_http_client=None, federation_client=Mock(),
)
self.hs.get_federation_handler = Mock()
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index bbd30f594b..ae0207366b 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -39,7 +39,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock(),
+ "red", federation_http_client=None, federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..5a18af8d34 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,17 +17,23 @@
# limitations under the License.
import json
+import re
import time
+import urllib.parse
from typing import Any, Dict, Optional
+from mock import patch
+
import attr
from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
+from synapse.types import JsonDict
from tests.server import FakeSite, make_request
+from tests.test_utils import FakeResponse
@attr.s
@@ -344,3 +350,111 @@ class RestHelper:
)
return channel.json_body
+
+ def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+ """Log in (as a new user) via OIDC
+
+ Returns the result of the final token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+ client_redirect_url = "https://x"
+
+ # first hit the redirect url (which will issue a cookie and state)
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ "/login/sso/redirect?redirectUrl=" + client_redirect_url,
+ )
+ # that will redirect to the OIDC IdP, but we skip that and go straight
+ # back to synapse's OIDC callback resource. However, we do need the "state"
+ # param that synapse passes to the IdP via query params, and the cookie that
+ # synapse passes to the client.
+ assert channel.code == 302
+ oauth_uri = channel.headers.getRawHeaders("Location")[0]
+ params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
+ redirect_uri = "%s?%s" % (
+ urllib.parse.urlparse(params["redirect_uri"][0]).path,
+ urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+ )
+ cookies = {}
+ for h in channel.headers.getRawHeaders("Set-Cookie"):
+ parts = h.split(";")
+ k, v = parts[0].split("=", maxsplit=1)
+ cookies[k] = v
+
+ # before we hit the callback uri, stub out some methods in the http client so
+ # that we don't have to handle full HTTPS requests.
+
+ # (expected url, json response) pairs, in the order we expect them.
+ expected_requests = [
+ # first we get a hit to the token endpoint, which we tell to return
+ # a dummy OIDC access token
+ ("https://issuer.test/token", {"access_token": "TEST"}),
+ # and then one to the user_info endpoint, which returns our remote user id.
+ ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+ ]
+
+ async def mock_req(method: str, uri: str, data=None, headers=None):
+ (expected_uri, resp_obj) = expected_requests.pop(0)
+ assert uri == expected_uri
+ resp = FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+ )
+ return resp
+
+ with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+ # now hit the callback URI with the right params and a made-up code
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ redirect_uri,
+ custom_headers=[
+ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+ ],
+ )
+
+ # expect a confirmation page
+ assert channel.code == 200
+
+ # fish the matrix login token out of the body of the confirmation page
+ m = re.search(
+ 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+ channel.result["body"].decode("utf-8"),
+ )
+ assert m
+ login_token = m.group(1)
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token and device id.
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ assert channel.code == 200
+ return channel.json_body
+
+
+# an 'oidc_config' suitable for login_with_oidc.
+TEST_OIDC_CONFIG = {
+ "enabled": True,
+ "discover": False,
+ "issuer": "https://issuer.test",
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["profile"],
+ "authorization_endpoint": "https://z",
+ "token_endpoint": "https://issuer.test/token",
+ "userinfo_endpoint": "https://issuer.test/userinfo",
+ "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 77246e478f..ac67a9de29 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import List, Union
from twisted.internet.defer import succeed
@@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.site import SynapseRequest
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.oidc import OIDCResource
+from synapse.types import JsonDict, UserID
from tests import unittest
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
@@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]
+ def default_config(self):
+ config = super().default_config()
+
+ # we enable OIDC as a way of testing SSO flows
+ oidc_config = {}
+ oidc_config.update(TEST_OIDC_CONFIG)
+ oidc_config["allow_existing_users"] = True
+
+ config["oidc_config"] = oidc_config
+ config["public_baseurl"] = "https://synapse.test"
+ return config
+
+ def create_resource_dict(self):
+ resource_dict = super().create_resource_dict()
+ # mount the OIDC resource at /_synapse/oidc
+ resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+ return resource_dict
+
def prepare(self, reactor, clock, hs):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass)
- def get_device_ids(self) -> List[str]:
+ def get_device_ids(self, access_token: str) -> List[str]:
# Get the list of devices so one can be deleted.
- request, channel = self.make_request(
- "GET", "devices", access_token=self.user_tok,
- ) # type: SynapseRequest, FakeChannel
-
- # Get the ID of the device.
- self.assertEqual(request.code, 200)
+ _, channel = self.make_request("GET", "devices", access_token=access_token,)
+ self.assertEqual(channel.code, 200)
return [d["device_id"] for d in channel.json_body["devices"]]
def delete_device(
- self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+ self,
+ access_token: str,
+ device: str,
+ expected_response: int,
+ body: Union[bytes, JsonDict] = b"",
) -> FakeChannel:
"""Delete an individual device."""
request, channel = self.make_request(
- "DELETE", "devices/" + device, body, access_token=self.user_tok
+ "DELETE", "devices/" + device, body, access_token=access_token,
) # type: SynapseRequest, FakeChannel
# Ensure the response is sane.
@@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""
Test user interactive authentication outside of registration.
"""
- device_id = self.get_device_ids()[0]
+ device_id = self.get_device_ids(self.user_tok)[0]
# Attempt to delete this device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_id, 401)
+ channel = self.delete_device(self.user_tok, device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow.
self.delete_device(
+ self.user_tok,
device_id,
200,
{
@@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
UIA - check that still works.
"""
- device_id = self.get_device_ids()[0]
- channel = self.delete_device(device_id, 401)
+ device_id = self.get_device_ids(self.user_tok)[0]
+ channel = self.delete_device(self.user_tok, device_id, 401)
session = channel.json_body["session"]
# Make another request providing the UI auth flow.
self.delete_device(
+ self.user_tok,
device_id,
200,
{
@@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login.
self.login("test", self.user_pass)
- device_ids = self.get_device_ids()
+ device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device.
@@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login.
self.login("test", self.user_pass)
- device_ids = self.get_device_ids()
+ device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_ids[0], 401)
+ channel = self.delete_device(self.user_tok, device_ids[0], 401)
# Grab the session
session = channel.json_body["session"]
@@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow, but try to delete the
# second device. This results in an error.
self.delete_device(
+ self.user_tok,
device_ids[1],
403,
{
@@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
},
)
+
+ def test_does_not_offer_password_for_sso_user(self):
+ login_resp = self.helper.login_via_oidc("username")
+ user_tok = login_resp["access_token"]
+ device_id = login_resp["device_id"]
+
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ channel = self.delete_device(user_tok, device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+ def test_does_not_offer_sso_for_password_user(self):
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ device_ids = self.get_device_ids(self.user_tok)
+ channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+ def test_offers_both_flows_for_upgraded_user(self):
+ """A user that had a password and then logged in with SSO should get both flows
+ """
+ login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ device_ids = self.get_device_ids(self.user_tok)
+ channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+ flows = channel.json_body["flows"]
+ # we have no particular expectations of ordering here
+ self.assertIn({"stages": ["m.login.password"]}, flows)
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+ self.assertEqual(len(flows), 2)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 8f0c2430e8..bcb21d0ced 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -121,6 +121,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
+ self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
self.hs.config.macaroon_secret_key = "test"
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index fbcf8d5b86..5e90d656f7 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -39,7 +39,7 @@ from tests.utils import default_config
class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
- return self.setup_test_homeserver(http_client=self.http_client)
+ return self.setup_test_homeserver(federation_http_client=self.http_client)
def create_test_resource(self):
return create_resource_tree(
@@ -172,7 +172,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
}
]
self.hs2 = self.setup_test_homeserver(
- http_client=self.http_client2, config=config
+ federation_http_client=self.http_client2, config=config
)
# wire up outbound POST /key/v2/query requests from hs2 so that they
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 2a3b2a8f27..4c749f1a61 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -214,7 +214,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
}
config["media_storage_providers"] = [provider_config]
- hs = self.setup_test_homeserver(config=config, http_client=client)
+ hs = self.setup_test_homeserver(config=config, federation_http_client=client)
return hs
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index ccdc8c2ecf..529b6bcded 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,41 +18,15 @@ import re
from mock import patch
-import attr
-
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone
from tests import unittest
from tests.server import FakeTransport
-@attr.s
-class FakeResponse:
- version = attr.ib()
- code = attr.ib()
- phrase = attr.ib()
- headers = attr.ib()
- body = attr.ib()
- absoluteURI = attr.ib()
-
- @property
- def request(self):
- @attr.s
- class FakeTransport:
- absoluteURI = self.absoluteURI
-
- return FakeTransport()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
-
-
class URLPreviewTests(unittest.HomeserverTestCase):
hijack_auth = True
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..4faf32e335 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -216,8 +216,9 @@ def make_request(
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
+ if path.startswith(b"/"):
+ path = path[1:]
path = b"/_matrix/client/r0/" + path
- path = path.replace(b"//", b"/")
if not path.startswith(b"/"):
path = b"/" + path
@@ -258,6 +259,7 @@ def make_request(
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
+ req.parseCookies()
req.requestReceived(method, path, b"1.1")
if await_result:
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..f5c6db900d 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+ _get_auth_chain_difference,
+ lexicographical_topological_sort,
+ resolve_events_with_store,
+)
from synapse.types import EventID
from tests import unittest
@@ -587,6 +591,128 @@ class SimpleParamStateTestCase(unittest.TestCase):
self.assert_dict(self.expected_combined_state, state)
+class AuthChainDifferenceTestCase(unittest.TestCase):
+ """We test that `_get_auth_chain_difference` correctly handles unpersisted
+ events.
+ """
+
+ def test_simple(self):
+ # Test getting the auth difference for a simple chain with a single
+ # unpersisted event:
+ #
+ # Unpersisted | Persisted
+ # |
+ # C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c}
+
+ state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {c.event_id})
+
+ def test_multiple_unpersisted_chain(self):
+ # Test getting the auth difference for a simple chain with multiple
+ # unpersisted events:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D -> C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, c.event_id})
+
+ def test_unpersisted_events_different_sets(self):
+ # Test getting the auth difference for with multiple unpersisted events
+ # in different branches:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D --> C -|-> B -> A
+ # E ----^ -|---^
+ # |
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ e = FakeEvent(
+ id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id, b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, e.event_id})
+
+
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 35dafbb904..3d7760d5d9 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -26,7 +26,7 @@ room_key = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.store = hs.get_datastore()
return hs
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..71c21d8c75 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -217,6 +217,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
)
self.assertSetEqual(difference, {"a", "b", "d", "e"})
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index cc1f3c53c5..a06ad2c03e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase):
servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d4f9e809db..a6303bf0ee 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,9 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
-
from canonicaljson import json
from twisted.internet import defer
@@ -30,12 +27,10 @@ from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["redaction_retention_period"] = "30d"
- return self.setup_test_homeserver(
- resource_for_federation=Mock(), http_client=None, config=config
- )
+ return config
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index ff972daeaa..d2aed66f6d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
-
from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(
- resource_for_federation=Mock(), http_client=None
- )
- return hs
-
def prepare(self, reactor, clock, hs: TestHomeServer):
# We can't test the RoomMemberStore on its own without the other event
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 1ce4ea3a01..fa45f8b3b7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -37,7 +37,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
self.addCleanup,
- http_client=self.http_client,
+ federation_http_client=self.http_client,
clock=self.hs_clock,
reactor=self.reactor,
)
diff --git a/tests/test_server.py b/tests/test_server.py
index c387a85f2e..6b2d2f0401 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -38,7 +38,10 @@ class JsonResourceTests(unittest.TestCase):
self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
+ self.addCleanup,
+ federation_http_client=None,
+ clock=self.hs_clock,
+ reactor=self.reactor,
)
def test_handler_for_request(self):
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index d232b72264..6873d45eb6 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,11 @@ import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
TV = TypeVar("TV")
@@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]:
sys.unraisablehook = unraisablehook # type: ignore
return cleanup
+
+
+@attr.s
+class FakeResponse:
+ """A fake twisted.web.IResponse object
+
+ there is a similar class at treq.test.test_response, but it lacks a `phrase`
+ attribute, and didn't support deliverBody until recently.
+ """
+
+ # HTTP response code
+ code = attr.ib(type=int)
+
+ # HTTP response phrase (eg b'OK' for a 200)
+ phrase = attr.ib(type=bytes)
+
+ # body of the response
+ body = attr.ib(type=bytes)
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/unittest.py b/tests/unittest.py
index a9d59e31f7..102b0a1f34 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload
from mock import Mock, patch
@@ -46,6 +46,7 @@ from synapse.logging.context import (
)
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
"""
Create a the root resource for the test server.
- The default implementation creates a JsonResource and calls each function in
- `servlets` to register servletes against it
+ The default calls `self.create_resource_dict` and builds the resultant dict
+ into a tree.
"""
- resource = JsonResource(self.hs)
+ root_resource = Resource()
+ create_resource_tree(self.create_resource_dict(), root_resource)
+ return root_resource
- for servlet in self.servlets:
- servlet(self.hs, resource)
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ """Create a resource tree for the test server
- return resource
+ A resource tree is a mapping from path to twisted.web.resource.
+
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servlets against it.
+ """
+ servlet_resource = JsonResource(self.hs)
+ for servlet in self.servlets:
+ servlet(self.hs, servlet_resource)
+ return {
+ "/_matrix/client": servlet_resource,
+ "/_synapse/admin": servlet_resource,
+ }
def default_config(self):
"""
@@ -691,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`.
"""
- def prepare(self, reactor, clock, homeserver):
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+ return d
+
+
+class TestTransportLayerServer(JsonResource):
+ """A test implementation of TransportLayerServer
+
+ authenticates incoming requests as `other.example.com`.
+ """
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
+ authenticator = Authenticator()
+
ratelimiter = FederationRateLimiter(
- clock,
+ hs.get_clock(),
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
@@ -706,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000,
),
)
- federation_server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
- return super().prepare(reactor, clock, homeserver)
+ federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config):
diff --git a/tests/utils.py b/tests/utils.py
index c8d3ffbaba..977eeaf6ee 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,13 +20,12 @@ import os
import time
import uuid
import warnings
-from inspect import getcallargs
from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
@@ -34,7 +33,6 @@ from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
@@ -42,7 +40,6 @@ from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
-from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
#
@@ -342,32 +339,9 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash
- fed = kwargs.get("resource_for_federation", None)
- if fed:
- register_federation_servlets(hs, fed)
-
return hs
-def register_federation_servlets(hs, resource):
- federation_server.register_servlets(
- hs,
- resource=resource,
- authenticator=federation_server.Authenticator(hs),
- ratelimiter=FederationRateLimiter(
- hs.get_clock(), config=hs.config.rc_federation
- ),
- )
-
-
-def get_mock_call_args(pattern_func, mock_func):
- """ Return the arguments the mock function was called with interpreted
- by the pattern functions argument list.
- """
- invoked_args, invoked_kargs = mock_func.call_args
- return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
-
-
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
@@ -553,86 +527,6 @@ class MockClock:
return d
-def _format_call(args, kwargs):
- return ", ".join(
- ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
- )
-
-
-class DeferredMockCallable:
- """A callable instance that stores a set of pending call expectations and
- return values for them. It allows a unit test to assert that the given set
- of function calls are eventually made, by awaiting on them to be called.
- """
-
- def __init__(self):
- self.expectations = []
- self.calls = []
-
- def __call__(self, *args, **kwargs):
- self.calls.append((args, kwargs))
-
- if not self.expectations:
- raise ValueError(
- "%r has no pending calls to handle call(%s)"
- % (self, _format_call(args, kwargs))
- )
-
- for (call, result, d) in self.expectations:
- if args == call[1] and kwargs == call[2]:
- d.callback(None)
- return result
-
- failure = AssertionError(
- "Was not expecting call(%s)" % (_format_call(args, kwargs))
- )
-
- for _, _, d in self.expectations:
- try:
- d.errback(failure)
- except Exception:
- pass
-
- raise failure
-
- def expect_call_and_return(self, call, result):
- self.expectations.append((call, result, defer.Deferred()))
-
- @defer.inlineCallbacks
- def await_calls(self, timeout=1000):
- deferred = defer.DeferredList(
- [d for _, _, d in self.expectations], fireOnOneErrback=True
- )
-
- timer = reactor.callLater(
- timeout / 1000,
- deferred.errback,
- AssertionError(
- "%d pending calls left: %s"
- % (
- len([e for e in self.expectations if not e[2].called]),
- [e for e in self.expectations if not e[2].called],
- )
- ),
- )
-
- yield deferred
-
- timer.cancel()
-
- self.calls = []
-
- def assert_had_no_calls(self):
- if self.calls:
- calls = self.calls
- self.calls = []
-
- raise AssertionError(
- "Expected not to received any calls, got:\n"
- + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
- )
-
-
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
"""
|