summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/admin.py89
-rw-r--r--synapse/handlers/deactivate_account.py2
-rw-r--r--synapse/handlers/federation.py10
-rw-r--r--synapse/handlers/initial_sync.py2
-rw-r--r--synapse/handlers/pagination.py15
-rw-r--r--synapse/handlers/register.py14
-rw-r--r--synapse/handlers/room.py24
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/handlers/saml_handler.py62
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/handlers/typing.py2
12 files changed, 105 insertions, 121 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 76d18a8ba8..9205865231 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,9 +14,11 @@
 # limitations under the License.
 
 import logging
+from typing import List
 
 from synapse.api.constants import Membership
-from synapse.types import RoomStreamToken
+from synapse.events import FrozenEvent
+from synapse.types import RoomStreamToken, StateMap
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -60,68 +62,6 @@ class AdminHandler(BaseHandler):
             ret["avatar_url"] = profile.avatar_url
         return ret
 
-    async def get_users(self):
-        """Function to retrieve a list of users in users table.
-
-        Args:
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        ret = await self.store.get_users()
-
-        return ret
-
-    async def get_users_paginate(self, start, limit, name, guests, deactivated):
-        """Function to retrieve a paginated list of users from
-        users list. This will return a json list of users.
-
-        Args:
-            start (int): start number to begin the query from
-            limit (int): number of rows to retrieve
-            name (string): filter for user names
-            guests (bool): whether to in include guest users
-            deactivated (bool): whether to include deactivated users
-        Returns:
-            defer.Deferred: resolves to json list[dict[str, Any]]
-        """
-        ret = await self.store.get_users_paginate(
-            start, limit, name, guests, deactivated
-        )
-
-        return ret
-
-    async def search_users(self, term):
-        """Function to search users list for one or more users with
-        the matched term.
-
-        Args:
-            term (str): search term
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        ret = await self.store.search_users(term)
-
-        return ret
-
-    def get_user_server_admin(self, user):
-        """
-        Get the admin bit on a user.
-
-        Args:
-            user_id (UserID): the (necessarily local) user to manipulate
-        """
-        return self.store.is_server_admin(user)
-
-    def set_user_server_admin(self, user, admin):
-        """
-        Set the admin bit on a user.
-
-        Args:
-            user_id (UserID): the (necessarily local) user to manipulate
-            admin (bool): whether or not the user should be an admin of this server
-        """
-        return self.store.set_server_admin(user, admin)
-
     async def export_user_data(self, user_id, writer):
         """Write all data we have on the user to the given writer.
 
@@ -134,7 +74,7 @@ class AdminHandler(BaseHandler):
             The returned value is that returned by `writer.finished()`.
         """
         # Get all rooms the user is in or has been in
-        rooms = await self.store.get_rooms_for_user_where_membership_is(
+        rooms = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id,
             membership_list=(
                 Membership.JOIN,
@@ -259,35 +199,26 @@ class ExfiltrationWriter(object):
     """Interface used to specify how to write exported data.
     """
 
-    def write_events(self, room_id, events):
+    def write_events(self, room_id: str, events: List[FrozenEvent]):
         """Write a batch of events for a room.
-
-        Args:
-            room_id (str)
-            events (list[FrozenEvent])
         """
         pass
 
-    def write_state(self, room_id, event_id, state):
+    def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
         """Write the state at the given event in the room.
 
         This only gets called for backward extremities rather than for each
         event.
-
-        Args:
-            room_id (str)
-            event_id (str)
-            state (dict[tuple[str, str], FrozenEvent])
         """
         pass
 
-    def write_invite(self, room_id, event, state):
+    def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
         """Write an invite for the room, with associated invite state.
 
         Args:
-            room_id (str)
-            event (FrozenEvent)
-            state (dict[tuple[str, str], dict]): A subset of the state at the
+            room_id
+            event
+            state: A subset of the state at the
                 invite, with a subset of the event keys (type, state_key
                 content and sender)
         """
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4426967f88..2afb390a92 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler):
             user_id (str): The user ID to reject pending invites for.
         """
         user = UserID.from_string(user_id)
-        pending_invites = await self.store.get_invited_rooms_for_user(user_id)
+        pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
 
         for room in pending_invites:
             try:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 61b6713c88..d4f9a792fc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import StateMap, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -89,7 +89,7 @@ class _NewEventInfo:
 
     event = attr.ib(type=EventBase)
     state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
-    auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
+    auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
 
 
 def shortstr(iterable, maxitems=5):
@@ -352,9 +352,7 @@ class FederationHandler(BaseHandler):
                     ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
-                    state_maps = list(
-                        ours.values()
-                    )  # type: list[dict[tuple[str, str], str]]
+                    state_maps = list(ours.values())  # type: list[StateMap[str]]
 
                     # we don't need this any more, let's delete it.
                     del ours
@@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
-        auth_events: Optional[Dict[Tuple[str, str], EventBase]],
+        auth_events: Optional[StateMap[EventBase]],
         backfilled: bool,
     ):
         """
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 44ec3e66ae..2e6755f19c 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
         if include_archived:
             memberships.append(Membership.LEAVE)
 
-        room_list = await self.store.get_rooms_for_user_where_membership_is(
+        room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id, membership_list=memberships
         )
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 00a6afc963..71d76202c9 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -88,6 +88,8 @@ class PaginationHandler(object):
         if hs.config.retention_enabled:
             # Run the purge jobs described in the configuration file.
             for job in hs.config.retention_purge_jobs:
+                logger.info("Setting up purge job with config: %s", job)
+
                 self.clock.looping_call(
                     run_as_background_process,
                     job["interval"],
@@ -130,11 +132,22 @@ class PaginationHandler(object):
         else:
             include_null = False
 
+        logger.info(
+            "[purge] Running purge job for %d < max_lifetime <= %d (include NULLs = %s)",
+            min_ms,
+            max_ms,
+            include_null,
+        )
+
         rooms = yield self.store.get_rooms_for_retention_period_in_range(
             min_ms, max_ms, include_null
         )
 
+        logger.debug("[purge] Rooms to purge: %s", rooms)
+
         for room_id, retention_policy in iteritems(rooms):
+            logger.info("[purge] Attempting to purge messages in room %s", room_id)
+
             if room_id in self._purges_in_progress_by_room:
                 logger.warning(
                     "[purge] not purging room %s as there's an ongoing purge running"
@@ -156,7 +169,7 @@ class PaginationHandler(object):
 
             stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts)
 
-            r = yield self.store.get_room_event_after_stream_ordering(
+            r = yield self.store.get_room_event_before_stream_ordering(
                 room_id, stream_ordering,
             )
             if not r:
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8a7d965feb..7ffc194f0c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -20,13 +20,7 @@ from twisted.internet import defer
 
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, LoginType
-from synapse.api.errors import (
-    AuthError,
-    Codes,
-    ConsentNotGivenError,
-    RegistrationError,
-    SynapseError,
-)
+from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import assert_params_in_dict
 from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -165,7 +159,7 @@ class RegistrationHandler(BaseHandler):
         Returns:
             Deferred[str]: user_id
         Raises:
-            RegistrationError if there was a problem registering.
+            SynapseError if there was a problem registering.
         """
         yield self.check_registration_ratelimit(address)
 
@@ -174,7 +168,7 @@ class RegistrationHandler(BaseHandler):
         if password:
             password_hash = yield self._auth_handler.hash(password)
 
-        if localpart:
+        if localpart is not None:
             yield self.check_username(localpart, guest_access_token=guest_access_token)
 
             was_guest = guest_access_token is not None
@@ -182,7 +176,7 @@ class RegistrationHandler(BaseHandler):
             if not was_guest:
                 try:
                     int(localpart)
-                    raise RegistrationError(
+                    raise SynapseError(
                         400, "Numeric user IDs are reserved for guest users."
                     )
                 except ValueError:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9cab2adbfb..9f50196ea7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+    Requester,
+    RoomAlias,
+    RoomID,
+    RoomStreamToken,
+    StateMap,
+    StreamToken,
+    UserID,
+)
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
@@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _update_upgraded_room_pls(
-        self, requester, old_room_id, new_room_id, old_room_state,
+        self,
+        requester: Requester,
+        old_room_id: str,
+        new_room_id: str,
+        old_room_state: StateMap[str],
     ):
         """Send updated power levels in both rooms after an upgrade
 
         Args:
-            requester (synapse.types.Requester): the user requesting the upgrade
-            old_room_id (str): the id of the room to be replaced
-            new_room_id (str): the id of the replacement room
-            old_room_state (dict[tuple[str, str], str]): the state map for the old room
+            requester: the user requesting the upgrade
+            old_room_id: the id of the room to be replaced
+            new_room_id: the id of the replacement room
+            old_room_state: the state map for the old room
 
         Returns:
             Deferred
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 03bb52ccfb..15e8aa5249 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -690,7 +690,7 @@ class RoomMemberHandler(object):
 
     @defer.inlineCallbacks
     def _get_inviter(self, user_id, room_id):
-        invite = yield self.store.get_invite_for_user_in_room(
+        invite = yield self.store.get_invite_for_local_user_in_room(
             user_id=user_id, room_id=room_id
         )
         if invite:
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 0082f85c26..7f411b53b9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,6 +24,7 @@ from saml2.client import Saml2Client
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
+from synapse.module_api import ModuleApi
 from synapse.rest.client.v1.login import SSOAuthHandler
 from synapse.types import (
     UserID,
@@ -31,6 +32,7 @@ from synapse.types import (
     mxid_localpart_allowed_characters,
 )
 from synapse.util.async_helpers import Linearizer
+from synapse.util.iterutils import chunk_seq
 
 logger = logging.getLogger(__name__)
 
@@ -59,7 +61,8 @@ class SamlHandler:
 
         # plugin to do custom mapping from saml response to mxid
         self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
-            hs.config.saml2_user_mapping_provider_config
+            hs.config.saml2_user_mapping_provider_config,
+            ModuleApi(hs, hs.get_auth_handler()),
         )
 
         # identifier for the external_ids table
@@ -112,10 +115,10 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
-        user_id = await self._map_saml_response_to_user(resp_bytes)
+        user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
         self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
 
-    async def _map_saml_response_to_user(self, resp_bytes):
+    async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -130,17 +133,28 @@ class SamlHandler:
             logger.warning("SAML2 response was not signed")
             raise SynapseError(400, "SAML2 response was not signed")
 
-        logger.info("SAML2 response: %s", saml2_auth.origxml)
-        logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
+        logger.debug("SAML2 response: %s", saml2_auth.origxml)
+        for assertion in saml2_auth.assertions:
+            # kibana limits the length of a log field, whereas this is all rather
+            # useful, so split it up.
+            count = 0
+            for part in chunk_seq(str(assertion), 10000):
+                logger.info(
+                    "SAML2 assertion: %s%s", "(%i)..." % (count,) if count else "", part
+                )
+                count += 1
 
-        try:
-            remote_user_id = saml2_auth.ava["uid"][0]
-        except KeyError:
-            logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "'uid' not in SAML2 response")
+        logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+            saml2_auth, client_redirect_url
+        )
+
+        if not remote_user_id:
+            raise Exception("Failed to extract remote user id from SAML response")
+
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -183,7 +197,7 @@ class SamlHandler:
             # Map saml response to user attributes using the configured mapping provider
             for i in range(1000):
                 attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
-                    saml2_auth, i
+                    saml2_auth, i, client_redirect_url=client_redirect_url,
                 )
 
                 logger.debug(
@@ -216,6 +230,8 @@ class SamlHandler:
                     500, "Unable to generate a Matrix ID from the SAML response"
                 )
 
+            logger.info("Mapped SAML user to local part %s", localpart)
+
             registered_user_id = await self._registration_handler.register_user(
                 localpart=localpart, default_display_name=displayname
             )
@@ -265,17 +281,35 @@ class SamlConfig(object):
 class DefaultSamlMappingProvider(object):
     __version__ = "0.0.1"
 
-    def __init__(self, parsed_config: SamlConfig):
+    def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
         """The default SAML user mapping provider
 
         Args:
             parsed_config: Module configuration
+            module_api: module api proxy
         """
         self._mxid_source_attribute = parsed_config.mxid_source_attribute
         self._mxid_mapper = parsed_config.mxid_mapper
 
+        self._grandfathered_mxid_source_attribute = (
+            module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+        )
+
+    def get_remote_user_id(
+        self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
+    ):
+        """Extracts the remote user id from the SAML response"""
+        try:
+            return saml_response.ava["uid"][0]
+        except KeyError:
+            logger.warning("SAML2 response lacks a 'uid' attestation")
+            raise SynapseError(400, "'uid' not in SAML2 response")
+
     def saml_response_to_user_attributes(
-        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+        self,
+        saml_response: saml2.response.AuthnResponse,
+        failures: int,
+        client_redirect_url: str,
     ) -> dict:
         """Maps some text from a SAML response to attributes of a new user
 
@@ -285,6 +319,8 @@ class DefaultSamlMappingProvider(object):
             failures: How many times a call to this function with this
                 saml_response has resulted in a failure
 
+            client_redirect_url: where the client wants to redirect to
+
         Returns:
             dict: A dict containing new user attributes. Possible keys:
                 * mxid_localpart (str): Required. The localpart of the user's mxid
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ef750d1497..110097eab9 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -179,7 +179,7 @@ class SearchHandler(BaseHandler):
         search_filter = Filter(filter_dict)
 
         # TODO: Search through left rooms too
-        rooms = yield self.store.get_rooms_for_user_where_membership_is(
+        rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
             user.to_string(),
             membership_list=[Membership.JOIN],
             # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2d3b8ba73c..cd95f85e3f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1662,7 +1662,7 @@ class SyncHandler(object):
             Membership.BAN,
         )
 
-        room_list = await self.store.get_rooms_for_user_where_membership_is(
+        room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id, membership_list=membership_list
         )
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index b635c339ed..d5ca9cb07b 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -257,7 +257,7 @@ class TypingHandler(object):
             "typing_key", self._latest_room_serial, rooms=[member.room_id]
         )
 
-    def get_all_typing_updates(self, last_id, current_id):
+    async def get_all_typing_updates(self, last_id, current_id):
         if last_id == current_id:
             return []