diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index f233854941..c1b8e98ae0 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -90,6 +90,8 @@ class SAML2Config(Config):
"grandfathered_mxid_source_attribute", "uid"
)
+ self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
+
# user_mapping_provider may be None if the key is present but has no value
ump_dict = saml2_config.get("user_mapping_provider") or {}
@@ -383,6 +385,14 @@ class SAML2Config(Config):
# value: "staff"
# - attribute: department
# value: "sales"
+
+ # If the metadata XML contains multiple IdP entities then the `idp_entityid`
+ # option must be set to the entity to redirect users to.
+ #
+ # Most deployments only have a single IdP entity and so should omit this
+ # option.
+ #
+ #idp_entityid: 'https://our_idp/entityid'
""" % {
"config_dir_path": config_dir_path
}
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 23278e36b7..4b6ab470d0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -49,6 +49,7 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
+from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
@@ -391,7 +392,7 @@ class FederationServer(FederationBase):
TRANSACTION_CONCURRENCY_LIMIT,
)
- async def on_context_state_request(
+ async def on_room_state_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin)
@@ -514,11 +515,12 @@ class FederationServer(FederationBase):
return {"event": ret_pdu.get_pdu_json(time_now)}
async def on_send_join_request(
- self, origin: str, content: JsonDict, room_id: str
+ self, origin: str, content: JsonDict
) -> Dict[str, Any]:
logger.debug("on_send_join_request: content: %s", content)
- room_version = await self.store.get_room_version(room_id)
+ assert_params_in_dict(content, ["room_id"])
+ room_version = await self.store.get_room_version(content["room_id"])
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
@@ -547,12 +549,11 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- async def on_send_leave_request(
- self, origin: str, content: JsonDict, room_id: str
- ) -> dict:
+ async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
logger.debug("on_send_leave_request: content: %s", content)
- room_version = await self.store.get_room_version(room_id)
+ assert_params_in_dict(content, ["room_id"])
+ room_version = await self.store.get_room_version(content["room_id"])
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
@@ -748,12 +749,8 @@ class FederationServer(FederationBase):
)
return ret
- async def on_exchange_third_party_invite_request(
- self, room_id: str, event_dict: Dict
- ):
- ret = await self.handler.on_exchange_third_party_invite_request(
- room_id, event_dict
- )
+ async def on_exchange_third_party_invite_request(self, event_dict: Dict):
+ ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
return ret
async def check_server_matches_acl(self, server_name: str, room_id: str):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 1f74c504a3..a9b73d713a 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -441,13 +441,13 @@ class FederationEventServlet(BaseFederationServlet):
class FederationStateV1Servlet(BaseFederationServlet):
- PATH = "/state/(?P<context>[^/]*)/?"
+ PATH = "/state/(?P<room_id>[^/]*)/?"
- # This is when someone asks for all data for a given context.
- async def on_GET(self, origin, content, query, context):
- return await self.handler.on_context_state_request(
+ # This is when someone asks for all data for a given room.
+ async def on_GET(self, origin, content, query, room_id):
+ return await self.handler.on_room_state_request(
origin,
- context,
+ room_id,
parse_string_from_args(query, "event_id", None, required=False),
)
@@ -464,16 +464,16 @@ class FederationStateIdsServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServlet):
- PATH = "/backfill/(?P<context>[^/]*)/?"
+ PATH = "/backfill/(?P<room_id>[^/]*)/?"
- async def on_GET(self, origin, content, query, context):
+ async def on_GET(self, origin, content, query, room_id):
versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)
if not limit:
return 400, {"error": "Did not include limit param"}
- return await self.handler.on_backfill_request(origin, context, versions, limit)
+ return await self.handler.on_backfill_request(origin, room_id, versions, limit)
class FederationQueryServlet(BaseFederationServlet):
@@ -488,9 +488,9 @@ class FederationQueryServlet(BaseFederationServlet):
class FederationMakeJoinServlet(BaseFederationServlet):
- PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
+ PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, _content, query, context, user_id):
+ async def on_GET(self, origin, _content, query, room_id, user_id):
"""
Args:
origin (unicode): The authenticated server_name of the calling server
@@ -512,16 +512,16 @@ class FederationMakeJoinServlet(BaseFederationServlet):
supported_versions = ["1"]
content = await self.handler.on_make_join_request(
- origin, context, user_id, supported_versions=supported_versions
+ origin, room_id, user_id, supported_versions=supported_versions
)
return 200, content
class FederationMakeLeaveServlet(BaseFederationServlet):
- PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
+ PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, content, query, context, user_id):
- content = await self.handler.on_make_leave_request(origin, context, user_id)
+ async def on_GET(self, origin, content, query, room_id, user_id):
+ content = await self.handler.on_make_leave_request(origin, room_id, user_id)
return 200, content
@@ -529,7 +529,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
- content = await self.handler.on_send_leave_request(origin, content, room_id)
+ content = await self.handler.on_send_leave_request(origin, content)
return 200, (200, content)
@@ -539,43 +539,43 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, room_id, event_id):
- content = await self.handler.on_send_leave_request(origin, content, room_id)
+ content = await self.handler.on_send_leave_request(origin, content)
return 200, content
class FederationEventAuthServlet(BaseFederationServlet):
- PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+ PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_GET(self, origin, content, query, context, event_id):
- return await self.handler.on_event_auth(origin, context, event_id)
+ async def on_GET(self, origin, content, query, room_id, event_id):
+ return await self.handler.on_event_auth(origin, room_id, event_id)
class FederationV1SendJoinServlet(BaseFederationServlet):
- PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+ PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, context, event_id):
- # TODO(paul): assert that context/event_id parsed from path actually
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ # TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
- content = await self.handler.on_send_join_request(origin, content, context)
+ content = await self.handler.on_send_join_request(origin, content)
return 200, (200, content)
class FederationV2SendJoinServlet(BaseFederationServlet):
- PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+ PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
- async def on_PUT(self, origin, content, query, context, event_id):
- # TODO(paul): assert that context/event_id parsed from path actually
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ # TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
- content = await self.handler.on_send_join_request(origin, content, context)
+ content = await self.handler.on_send_join_request(origin, content)
return 200, content
class FederationV1InviteServlet(BaseFederationServlet):
- PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+ PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, context, event_id):
+ async def on_PUT(self, origin, content, query, room_id, event_id):
# We don't get a room version, so we have to assume its EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
@@ -590,12 +590,12 @@ class FederationV1InviteServlet(BaseFederationServlet):
class FederationV2InviteServlet(BaseFederationServlet):
- PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+ PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
- async def on_PUT(self, origin, content, query, context, event_id):
- # TODO(paul): assert that context/event_id parsed from path actually
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ # TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
room_version = content["room_version"]
@@ -617,9 +617,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id):
- content = await self.handler.on_exchange_third_party_invite_request(
- room_id, content
- )
+ content = await self.handler.on_exchange_third_party_invite_request(content)
return 200, content
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9fc8444228..5c6458eb52 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -226,7 +226,7 @@ class ApplicationServicesHandler:
new_token: Optional[int],
users: Collection[Union[str, UserID]],
):
- logger.info("Checking interested services for %s" % (stream_key))
+ logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
# Only handle typing if we have the latest token
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4fc619e802..f27dd84213 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -55,6 +55,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.handlers._base import BaseHandler
+from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
@@ -2698,7 +2699,7 @@ class FederationHandler(BaseHandler):
)
async def on_exchange_third_party_invite_request(
- self, room_id: str, event_dict: JsonDict
+ self, event_dict: JsonDict
) -> None:
"""Handle an exchange_third_party_invite request from a remote server
@@ -2706,12 +2707,11 @@ class FederationHandler(BaseHandler):
into a normal m.room.member invite.
Args:
- room_id: The ID of the room.
-
- event_dict (dict[str, Any]): Dictionary containing the event body.
+ event_dict: Dictionary containing the event body.
"""
- room_version = await self.store.get_room_version_id(room_id)
+ assert_params_in_dict(event_dict, ["room_id"])
+ room_version = await self.store.get_room_version_id(event_dict["room_id"])
# NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much.
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index be8562d47b..4bfd8d5617 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -38,7 +38,12 @@ from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ JsonDict,
+ UserID,
+ contains_invalid_mxid_characters,
+ map_username_to_mxid_localpart,
+)
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -885,10 +890,12 @@ class OidcHandler(BaseHandler):
"Retrieved user attributes from user mapping provider: %r", attributes
)
- if not attributes["localpart"]:
- raise MappingException("localpart is empty")
-
- localpart = map_username_to_mxid_localpart(attributes["localpart"])
+ localpart = attributes["localpart"]
+ if not localpart:
+ raise MappingException(
+ "Error parsing OIDC response: OIDC mapping provider plugin "
+ "did not return a localpart value"
+ )
user_id = UserID(localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
@@ -908,6 +915,11 @@ class OidcHandler(BaseHandler):
# This mxid is taken
raise MappingException("mxid '{}' is already taken".format(user_id))
else:
+ # Since the localpart is provided via a potentially untrusted module,
+ # ensure the MXID is valid before registering.
+ if contains_invalid_mxid_characters(localpart):
+ raise MappingException("localpart is invalid: %s" % (localpart,))
+
# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
@@ -1076,6 +1088,9 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
) -> UserAttribute:
localpart = self._config.localpart_template.render(user=userinfo).strip()
+ # Ensure only valid characters are included in the MXID.
+ localpart = map_username_to_mxid_localpart(localpart)
+
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index c242c409cf..153cbae7b9 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -158,7 +158,8 @@ class ReceiptEventSource:
if from_key == to_key:
return [], to_key
- # We first need to fetch all new receipts
+ # Fetch all read receipts for all rooms, up to a limit of 100. This is ordered
+ # by most recent.
rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
from_key=from_key, to_key=to_key
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index aee772239a..5d9b555b13 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -31,6 +31,7 @@ from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.types import (
UserID,
+ contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
@@ -58,6 +59,7 @@ class SamlHandler(BaseHandler):
def __init__(self, hs: "synapse.server.HomeServer"):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+ self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -100,7 +102,7 @@ class SamlHandler(BaseHandler):
URL to redirect to
"""
reqid, info = self._saml_client.prepare_for_authenticate(
- relay_state=client_redirect_url
+ entityid=self._saml_idp_entityid, relay_state=client_redirect_url
)
# Since SAML sessions timeout it is useful to log when they were created.
@@ -317,6 +319,11 @@ class SamlHandler(BaseHandler):
"Unable to generate a Matrix ID from the SAML response"
)
+ # Since the localpart is provided via a potentially untrusted module,
+ # ensure the MXID is valid before registering.
+ if contains_invalid_mxid_characters(localpart):
+ raise MappingException("localpart is invalid: %s" % (localpart,))
+
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index ca7917c989..1e7949a323 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -278,7 +278,8 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
- """Get receipts for all rooms between two stream_ids.
+ """Get receipts for all rooms between two stream_ids, up
+ to a limit of the latest 100 read receipts.
Args:
to_key: Max stream id to fetch receipts upto.
@@ -294,12 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id DESC
+ LIMIT 100
"""
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ?
+ ORDER BY stream_id DESC
+ LIMIT 100
"""
txn.execute(sql, [to_key])
diff --git a/synapse/types.py b/synapse/types.py
index 5e52c1eb12..7c9f716804 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -318,14 +318,14 @@ mxid_localpart_allowed_characters = set(
)
-def contains_invalid_mxid_characters(localpart):
+def contains_invalid_mxid_characters(localpart: str) -> bool:
"""Check for characters not allowed in an mxid or groupid localpart
Args:
- localpart (basestring): the localpart to be checked
+ localpart: the localpart to be checked
Returns:
- bool: True if there are any naughty characters
+ True if there are any naughty characters
"""
return any(c not in mxid_localpart_allowed_characters for c in localpart)
|