Diffstat (limited to 'synapse/handlers')
-rw-r--r--  synapse/handlers/federation.py    145
-rw-r--r--  synapse/handlers/initial_sync.py   19
-rw-r--r--  synapse/handlers/message.py         5
-rw-r--r--  synapse/handlers/saml_handler.py  198
-rw-r--r--  synapse/handlers/search.py         34
5 files changed, 325 insertions(+), 76 deletions(-)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bc26921768..cf9c46d027 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -63,8 +63,9 @@ from synapse.replication.http.federation import (
 )
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import UserID, get_domain_from_id
-from synapse.util import unwrapFirstError
+from synapse.util import batch_iter, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -182,7 +183,7 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+        logger.info("handling received PDU: %s", pdu)
 
         # We reprocess pdus when we have seen them only as outliers
         existing = yield self.store.get_event(
@@ -278,9 +279,15 @@ class FederationHandler(BaseHandler):
                             len(missing_prevs),
                         )
 
-                        yield self._get_missing_events_for_pdu(
-                            origin, pdu, prevs, min_depth
-                        )
+                        try:
+                            yield self._get_missing_events_for_pdu(
+                                origin, pdu, prevs, min_depth
+                            )
+                        except Exception as e:
+                            raise Exception(
+                                "Error fetching missing prev_events for %s: %s"
+                                % (event_id, e)
+                            )
 
                         # Update the set of things we've seen after trying to
                         # fetch the missing stuff
@@ -292,14 +299,6 @@ class FederationHandler(BaseHandler):
                                 room_id,
                                 event_id,
                             )
-                elif missing_prevs:
-                    logger.info(
-                        "[%s %s] Not recursively fetching %d missing prev_events: %s",
-                        room_id,
-                        event_id,
-                        len(missing_prevs),
-                        shortstr(missing_prevs),
-                    )
 
             if prevs - seen:
                 # We've still not been able to get all of the prev_events for this event.
@@ -344,6 +343,12 @@ class FederationHandler(BaseHandler):
                         affected=pdu.event_id,
                     )
 
+                logger.info(
+                    "Event %s is missing prev_events: calculating state for a "
+                    "backwards extremity",
+                    event_id,
+                )
+
                 # Calculate the state after each of the previous events, and
                 # resolve them to find the correct state at the current event.
                 auth_chains = set()
@@ -364,10 +369,7 @@ class FederationHandler(BaseHandler):
                     # know about
                     for p in prevs - seen:
                         logger.info(
-                            "[%s %s] Requesting state at missing prev_event %s",
-                            room_id,
-                            event_id,
-                            p,
+                            "Requesting state at missing prev_event %s", event_id,
                         )
 
                         room_version = yield self.store.get_room_version(room_id)
@@ -379,11 +381,9 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = yield self.federation_client.get_state_for_room(
-                                origin, room_id, p
-                            )
+                            ) = yield self._get_state_for_room(origin, room_id, p)
 
-                            # we want the state *after* p; get_state_for_room returns the
+                            # we want the state *after* p; _get_state_for_room returns the
                             # state *before* p.
                             remote_event = yield self.federation_client.get_pdu(
                                 [origin], p, room_version, outlier=True
@@ -425,7 +425,7 @@ class FederationHandler(BaseHandler):
                     evs = yield self.store.get_events(
                         list(state_map.values()),
                         get_prev_content=False,
-                        check_redacted=False,
+                        redact_behaviour=EventRedactBehaviour.AS_IS,
                     )
                     event_map.update(evs)
 
@@ -584,6 +584,97 @@ class FederationHandler(BaseHandler):
                         raise
 
     @defer.inlineCallbacks
+    @log_function
+    def _get_state_for_room(self, destination, room_id, event_id):
+        """Requests all of the room state at a given event from a remote homeserver.
+
+        Args:
+            destination (str): The remote homeserver to query for the state.
+            room_id (str): The id of the room we're interested in.
+            event_id (str): The id of the event we want the state at.
+
+        Returns:
+            Deferred[Tuple[List[EventBase], List[EventBase]]]:
+                A list of events in the state, and a list of events in the auth chain
+                for the given event.
+        """
+        (
+            state_event_ids,
+            auth_event_ids,
+        ) = yield self.federation_client.get_room_state_ids(
+            destination, room_id, event_id=event_id
+        )
+
+        desired_events = set(state_event_ids + auth_event_ids)
+        event_map = yield self._get_events_from_store_or_dest(
+            destination, room_id, desired_events
+        )
+
+        failed_to_fetch = desired_events - event_map.keys()
+        if failed_to_fetch:
+            logger.warning(
+                "Failed to fetch missing state/auth events for %s %s",
+                event_id,
+                failed_to_fetch,
+            )
+
+        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+
+        auth_chain.sort(key=lambda e: e.depth)
+
+        return pdus, auth_chain
+
+    @defer.inlineCallbacks
+    def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+        """Fetch events from a remote destination, checking if we already have them.
+
+        Args:
+            destination (str)
+            room_id (str)
+            event_ids (Iterable[str])
+
+        Returns:
+            Deferred[dict[str, EventBase]]: A deferred resolving to a map
+            from event_id to event
+        """
+        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+
+        missing_events = set(event_ids) - fetched_events.keys()
+
+        if not missing_events:
+            return fetched_events
+
+        logger.debug(
+            "Fetching unknown state/auth events %s for room %s",
+            missing_events,
+            room_id,
+        )
+
+        room_version = yield self.store.get_room_version(room_id)
+
+        # XXX 20 requests at once? really?
+        for batch in batch_iter(missing_events, 20):
+            deferreds = [
+                run_in_background(
+                    self.federation_client.get_pdu,
+                    destinations=[destination],
+                    event_id=e_id,
+                    room_version=room_version,
+                )
+                for e_id in batch
+            ]
+
+            res = yield make_deferred_yieldable(
+                defer.DeferredList(deferreds, consumeErrors=True)
+            )
+            for success, result in res:
+                if success and result:
+                    fetched_events[result.event_id] = result
+
+        return fetched_events
+
+    @defer.inlineCallbacks
     def _process_received_pdu(self, origin, event, state, auth_chain):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
@@ -723,7 +814,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.federation_client.get_state_for_room(
+            state, auth = yield self._get_state_for_room(
                 destination=dest, room_id=room_id, event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
@@ -911,7 +1002,9 @@ class FederationHandler(BaseHandler):
         forward_events = yield self.store.get_successor_events(list(extremities))
 
         extremities_events = yield self.store.get_events(
-            forward_events, check_redacted=False, get_prev_content=False
+            forward_events,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
+            get_prev_content=False,
         )
 
         # We set `check_history_visibility_only` as we might otherwise get false
@@ -1210,7 +1303,7 @@ class FederationHandler(BaseHandler):
             # Check whether this room is the result of an upgrade of a room we already know
             # about. If so, migrate over user information
             predecessor = yield self.store.get_room_predecessor(room_id)
-            if not predecessor:
+            if not predecessor or not isinstance(predecessor.get("room_id"), str):
                 return
             old_room_id = predecessor["room_id"]
             logger.debug(
@@ -1453,7 +1546,7 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
         origin, event, event_format_version = yield self._make_and_verify_event(
-            target_hosts, room_id, user_id, "leave", content=content,
+            target_hosts, room_id, user_id, "leave", content=content
         )
         # Mark as outlier as we don't have any state for this event; we're not
         # even in the room.
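
Note on the new helpers above: _get_events_from_store_or_dest only asks the remote server for events we do not already have locally, and it does so in parallel batches of 20, keeping whatever succeeds. Below is a standalone sketch of that pattern in plain Twisted; fetch_one is a hypothetical callable standing in for federation_client.get_pdu and is assumed to return a dict with an "event_id" key (or None), and batch_iter is re-implemented here so the snippet runs on its own.

    from twisted.internet import defer


    def batch_iter(iterable, size):
        """Yield lists of at most `size` items from `iterable`."""
        batch = []
        for item in iterable:
            batch.append(item)
            if len(batch) == size:
                yield batch
                batch = []
        if batch:
            yield batch


    @defer.inlineCallbacks
    def fetch_in_batches(event_ids, fetch_one, batch_size=20):
        """Fetch each event id via fetch_one, at most batch_size in flight at once."""
        fetched = {}
        for batch in batch_iter(event_ids, batch_size):
            deferreds = [defer.maybeDeferred(fetch_one, e_id) for e_id in batch]
            results = yield defer.DeferredList(deferreds, consumeErrors=True)
            for success, result in results:
                # Failed or empty lookups are skipped, mirroring the best-effort
                # behaviour of _get_events_from_store_or_dest above.
                if success and result:
                    fetched[result["event_id"]] = result
        return fetched

DeferredList with consumeErrors=True means a single failed lookup does not abort the batch; as in the handler above, events that could not be fetched are simply left out of the returned map and logged as missing by the caller.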
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..73c110a92b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = SnapshotCache()
+        self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
@@ -79,17 +79,14 @@ class InitialSyncHandler(BaseHandler):
             as_client_event,
             include_archived,
         )
-        now_ms = self.clock.time_msec()
-        result = self.snapshot_cache.get(now_ms, key)
-        if result is not None:
-            return result
 
-        return self.snapshot_cache.set(
-            now_ms,
+        return self.snapshot_cache.wrap(
             key,
-            self._snapshot_all_rooms(
-                user_id, pagin_config, as_client_event, include_archived
-            ),
+            self._snapshot_all_rooms,
+            user_id,
+            pagin_config,
+            as_client_event,
+            include_archived,
         )
 
     @defer.inlineCallbacks
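
The SnapshotCache to ResponseCache switch means concurrent initial syncs with the same key now share one _snapshot_all_rooms computation instead of each starting their own. Below is a minimal sketch of that deduplication idea; it is not Synapse's ResponseCache, which additionally handles logging contexts and hands each caller its own observer of the result.

    from twisted.internet import defer


    class MiniResponseCache:
        """Deduplicate concurrent requests that share a key."""

        def __init__(self):
            self._pending = {}  # key -> Deferred for the in-flight computation

        def wrap(self, key, callback, *args, **kwargs):
            entry = self._pending.get(key)
            if entry is None:
                # Nobody is computing this yet: start it and remember the Deferred.
                entry = defer.maybeDeferred(callback, *args, **kwargs)
                self._pending[key] = entry

                def _evict(result):
                    # Forget the entry once it completes so later requests recompute.
                    self._pending.pop(key, None)
                    return result

                entry.addBoth(_evict)
            # Every concurrent caller gets the same Deferred, so the expensive
            # callback runs once; callers should treat the result as read-only.
            return entry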
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..bf9add7fe2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID, create_requester
 from synapse.util.async_helpers import Linearizer
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
             if event.type == EventTypes.Redaction:
                 original_event = yield self.store.get_event(
                     event.redacts,
-                    check_redacted=False,
+                    redact_behaviour=EventRedactBehaviour.AS_IS,
                     get_prev_content=False,
                     allow_rejected=False,
                     allow_none=True,
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
         if event.type == EventTypes.Redaction:
             original_event = yield self.store.get_event(
                 event.redacts,
-                check_redacted=False,
+                redact_behaviour=EventRedactBehaviour.AS_IS,
                 get_prev_content=False,
                 allow_rejected=False,
                 allow_none=True,
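
Here and in federation.py, check_redacted=False becomes redact_behaviour=EventRedactBehaviour.AS_IS: a boolean flag replaced by an enum, so call sites name the behaviour they want. Below is a generic sketch of that flag-to-enum pattern; only AS_IS appears in this diff, and the other member plus the pruning logic are illustrative, not Synapse's implementation.

    import enum


    class RedactBehaviour(enum.Enum):
        AS_IS = enum.auto()   # return the stored event even if it was redacted
        REDACT = enum.auto()  # return the pruned form of a redacted event


    def fetch_event(store, event_id, behaviour=RedactBehaviour.REDACT):
        event = store[event_id]  # `store` is a plain dict of event_id -> dict here
        if behaviour is RedactBehaviour.REDACT and event.get("redacted_because"):
            # Keep only a few top-level keys to stand in for spec-compliant pruning.
            keep = {"event_id", "type", "room_id", "sender"}
            return {k: v for k, v in event.items() if k in keep}
        return event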
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
+from typing import Tuple
 
 import attr
 import saml2
+import saml2.response
 from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    UserID,
+    map_username_to_mxid_localpart,
+    mxid_localpart_allowed_characters,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class Saml2SessionData:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib()
+
+
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
         self._datastore = hs.get_datastore()
         self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
-        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
-        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # plugin to do custom mapping from saml response to mxid
+        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+            hs.config.saml2_user_mapping_provider_config
+        )
 
         # identifier for the external_ids table
         self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
             remote_user_id = saml2_auth.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "uid not in SAML2 response")
-
-        try:
-            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
-        except KeyError:
-            logger.warning(
-                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
-            )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
-            )
+            raise SynapseError(400, "'uid' not in SAML2 response")
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        displayName = saml2_auth.ava.get("displayName", [None])[0]
-
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
                     )
                     return registered_user_id
 
-            # figure out a new mxid for this user
-            base_mxid_localpart = self._mxid_mapper(mxid_source)
+            # Map saml response to user attributes using the configured mapping provider
+            for i in range(1000):
+                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+                    saml2_auth, i
+                )
+
+                logger.debug(
+                    "Retrieved SAML attributes from user mapping provider: %s "
+                    "(attempt %d)",
+                    attribute_dict,
+                    i,
+                )
+
+                localpart = attribute_dict.get("mxid_localpart")
+                if not localpart:
+                    logger.error(
+                        "SAML mapping provider plugin did not return a "
+                        "mxid_localpart object"
+                    )
+                    raise SynapseError(500, "Error parsing SAML2 response")
 
-            suffix = 0
-            while True:
-                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                displayname = attribute_dict.get("displayname")
+
+                # Check if this mxid already exists
                 if not await self._datastore.get_users_by_id_case_insensitive(
                     UserID(localpart, self._hostname).to_string()
                 ):
+                    # This mxid is free
                     break
-                suffix += 1
-            logger.info("Allocating mxid for new user with localpart %s", localpart)
+            else:
+                # Unable to generate a username in 1000 iterations;
+                # give up and return an error to the user
+                raise SynapseError(
+                    500, "Unable to generate a Matrix ID from the SAML response"
+                )
 
             registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=displayName
+                localpart=localpart, default_display_name=displayname
             )
+
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
@@ -205,9 +236,120 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
+DOT_REPLACE_PATTERN = re.compile(
+    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+    username = username.lower()
+    username = DOT_REPLACE_PATTERN.sub(".", username)
+
+    # regular mxids aren't allowed to start with an underscore either
+    username = re.sub("^_", "", username)
+    return username
+
+
+MXID_MAPPER_MAP = {
+    "hexencode": map_username_to_mxid_localpart,
+    "dotreplace": dot_replace_for_mxid,
+}
+
+
 @attr.s
-class Saml2SessionData:
-    """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+    mxid_source_attribute = attr.ib()
+    mxid_mapper = attr.ib()
 
-    # time the session was created, in milliseconds
-    creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+    __version__ = "0.0.1"
+
+    def __init__(self, parsed_config: SamlConfig):
+        """The default SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+        """
+        self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        self._mxid_mapper = parsed_config.mxid_mapper
+
+    def saml_response_to_user_attributes(
+        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+        """
+        try:
+            mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
+        # Use the configured mapper for this mxid_source
+        base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid
+        localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+        # Retrieve the display name from the saml response
+        # If displayname is None, the mxid_localpart will be used instead
+        displayname = saml_response.ava.get("displayName", [None])[0]
+
+        return {
+            "mxid_localpart": localpart,
+            "displayname": displayname,
+        }
+
+    @staticmethod
+    def parse_config(config: dict) -> SamlConfig:
+        """Parse the dict provided by the homeserver's config
+        Args:
+            config: A dictionary containing configuration options for this provider
+        Returns:
+            SamlConfig: A custom config object for this module
+        """
+        # Parse config options and use defaults where necessary
+        mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+        mapping_type = config.get("mxid_mapping", "hexencode")
+
+        # Retrieve the associated mapping function
+        try:
+            mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+        except KeyError:
+            raise ConfigError(
+                "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+                "mxid_mapping value" % (mapping_type,)
+            )
+
+        return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+    @staticmethod
+    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+        """Returns the required attributes of a SAML
+
+        Args:
+            config: A SamlConfig object containing configuration params for this provider
+
+        Returns:
+            tuple[set,set]: The first set equates to the saml auth response
+                attributes that are required for the module to function, whereas the
+                second set consists of those attributes which can be used if
+                available, but are not necessary
+        """
+        return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.storage.state import StateFilter
 from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
+        self.auth = hs.get_auth()
 
     @defer.inlineCallbacks
     def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
             room_id (str): id of the room to search through.
 
         Returns:
-            Deferred[iterable[unicode]]: predecessor room ids
+            Deferred[iterable[str]]: predecessor room ids
         """
 
         historical_room_ids = []
 
-        while True:
-            predecessor = yield self.store.get_room_predecessor(room_id)
+        # The initial room must have been known for us to get this far
+        predecessor = yield self.store.get_room_predecessor(room_id)
 
-            # If no predecessor, assume we've hit a dead end
+        while True:
             if not predecessor:
+                # We have reached the end of the chain of predecessors
+                break
+
+            if not isinstance(predecessor.get("room_id"), str):
+                # This predecessor object is malformed. Exit here
+                break
+
+            predecessor_room_id = predecessor["room_id"]
+
+            # Don't add it to the list until we have checked that it is a known room
+            try:
+                next_predecessor_room = yield self.store.get_room_predecessor(
+                    predecessor_room_id
+                )
+            except NotFoundError:
+                # The predecessor is not a known room, so we are done here
                 break
 
-            # Add predecessor's room ID
-            historical_room_ids.append(predecessor["room_id"])
+            historical_room_ids.append(predecessor_room_id)
 
-            # Scan through the old room for further predecessors
-            room_id = predecessor["room_id"]
+            # And repeat
+            predecessor = next_predecessor_room
 
         return historical_room_ids
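
The rewritten predecessor walk stops in three cases: no predecessor, a malformed predecessor entry, or a predecessor room the server does not know about (NotFoundError). Below is a standalone sketch of the same traversal; get_predecessor is a hypothetical lookup standing in for store.get_room_predecessor, and raises KeyError for unknown rooms.

    def walk_predecessors(room_id, get_predecessor):
        """Follow the chain of room predecessors starting from room_id.

        get_predecessor(room_id) returns a dict like {"room_id": ...} or None,
        and raises KeyError for rooms we know nothing about.
        """
        historical = []
        predecessor = get_predecessor(room_id)
        while predecessor:
            prev_id = predecessor.get("room_id")
            if not isinstance(prev_id, str):
                break  # malformed predecessor entry in the create event
            try:
                next_predecessor = get_predecessor(prev_id)
            except KeyError:
                break  # the predecessor is not a room we know about
            historical.append(prev_id)
            predecessor = next_predecessor
        return historical


    rooms = {"!c": {"room_id": "!b"}, "!b": {"room_id": "!a"}, "!a": None}
    print(walk_predecessors("!c", lambda r: rooms[r]))  # ['!b', '!a']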