summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/auth.py58
-rw-r--r--synapse/handlers/deactivate_account.py3
-rw-r--r--synapse/handlers/profile.py67
-rw-r--r--synapse/handlers/register.py26
-rw-r--r--synapse/handlers/room.py77
-rw-r--r--synapse/handlers/room_member.py6
-rw-r--r--synapse/handlers/search.py45
-rw-r--r--synapse/handlers/sync.py5
8 files changed, 215 insertions, 72 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bd1a322563..e32c93e234 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2060,6 +2060,10 @@ CHECK_AUTH_CALLBACK = Callable[
         Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
     ],
 ]
+GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
 
 
 class PasswordAuthProvider:
@@ -2072,6 +2076,9 @@ class PasswordAuthProvider:
         # lists of callbacks
         self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
         self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+        self.get_username_for_registration_callbacks: List[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
 
         # Mapping from login type to login parameters
         self._supported_login_types: Dict[str, Iterable[str]] = {}
@@ -2086,6 +2093,9 @@ class PasswordAuthProvider:
         auth_checkers: Optional[
             Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
         ] = None,
+        get_username_for_registration: Optional[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         # Register check_3pid_auth callback
         if check_3pid_auth is not None:
@@ -2130,6 +2140,11 @@ class PasswordAuthProvider:
                 # Add the new method to the list of auth_checker_callbacks for this login type
                 self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
 
+        if get_username_for_registration is not None:
+            self.get_username_for_registration_callbacks.append(
+                get_username_for_registration,
+            )
+
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
 
@@ -2285,3 +2300,46 @@ class PasswordAuthProvider:
             except Exception as e:
                 logger.warning("Failed to run module API callback %s: %s", callback, e)
                 continue
+
+    async def get_username_for_registration(
+        self,
+        uia_results: JsonDict,
+        params: JsonDict,
+    ) -> Optional[str]:
+        """Defines the username to use when registering the user, using the credentials
+        and parameters provided during the UIA flow.
+
+        Stops at the first callback that returns a string.
+
+        Args:
+            uia_results: The credentials provided during the UIA flow.
+            params: The parameters provided by the registration request.
+
+        Returns:
+            The localpart to use when registering this user, or None if no module
+            returned a localpart.
+        """
+        for callback in self.get_username_for_registration_callbacks:
+            try:
+                res = await callback(uia_results, params)
+
+                if isinstance(res, str):
+                    return res
+                elif res is not None:
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " get_username_for_registration callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error(
+                    "Module raised an exception in get_username_for_registration: %s",
+                    e,
+                )
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return None
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index bee62cf360..7a13d76a68 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -157,6 +157,9 @@ class DeactivateAccountHandler:
         # Mark the user as deactivated.
         await self.store.set_user_deactivated_status(user_id, True)
 
+        # Remove account data (including ignored users and push rules).
+        await self.store.purge_account_data_for_user(user_id)
+
         return identity_server_supports_unbinding
 
     async def _reject_pending_invites_for_user(self, user_id: str) -> None:
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 6b5a6ded8b..36e3ad2ba9 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -31,6 +31,8 @@ from synapse.types import (
     create_requester,
     get_domain_from_id,
 )
+from synapse.util.caches.descriptors import cached
+from synapse.util.stringutils import parse_and_validate_mxc_uri
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -64,6 +66,11 @@ class ProfileHandler:
         self.user_directory_handler = hs.get_user_directory_handler()
         self.request_ratelimiter = hs.get_request_ratelimiter()
 
+        self.max_avatar_size = hs.config.server.max_avatar_size
+        self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes
+
+        self.server_name = hs.config.server.server_name
+
         if hs.config.worker.run_background_tasks:
             self.clock.looping_call(
                 self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
@@ -286,6 +293,9 @@ class ProfileHandler:
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
             )
 
+        if not await self.check_avatar_size_and_mime_type(new_avatar_url):
+            raise SynapseError(403, "This avatar is not allowed", Codes.FORBIDDEN)
+
         avatar_url_to_set: Optional[str] = new_avatar_url
         if new_avatar_url == "":
             avatar_url_to_set = None
@@ -307,6 +317,63 @@ class ProfileHandler:
 
         await self._update_join_states(requester, target_user)
 
+    @cached()
+    async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:
+        """Check that the size and content type of the avatar at the given MXC URI are
+        within the configured limits.
+
+        Args:
+            mxc: The MXC URI at which the avatar can be found.
+
+        Returns:
+             A boolean indicating whether the file can be allowed to be set as an avatar.
+        """
+        if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
+            return True
+
+        server_name, _, media_id = parse_and_validate_mxc_uri(mxc)
+
+        if server_name == self.server_name:
+            media_info = await self.store.get_local_media(media_id)
+        else:
+            media_info = await self.store.get_cached_remote_media(server_name, media_id)
+
+        if media_info is None:
+            # Both configuration options need to access the file's metadata, and
+            # retrieving remote avatars just for this becomes a bit of a faff, especially
+            # if e.g. the file is too big. It's also generally safe to assume most files
+            # used as avatar are uploaded locally, or if the upload didn't happen as part
+            # of a PUT request on /avatar_url that the file was at least previewed by the
+            # user locally (and therefore downloaded to the remote media cache).
+            logger.warning("Forbidding avatar change to %s: avatar not on server", mxc)
+            return False
+
+        if self.max_avatar_size:
+            # Ensure avatar does not exceed max allowed avatar size
+            if media_info["media_length"] > self.max_avatar_size:
+                logger.warning(
+                    "Forbidding avatar change to %s: %d bytes is above the allowed size "
+                    "limit",
+                    mxc,
+                    media_info["media_length"],
+                )
+                return False
+
+        if self.allowed_avatar_mimetypes:
+            # Ensure the avatar's file type is allowed
+            if (
+                self.allowed_avatar_mimetypes
+                and media_info["media_type"] not in self.allowed_avatar_mimetypes
+            ):
+                logger.warning(
+                    "Forbidding avatar change to %s: mimetype %s not allowed",
+                    mxc,
+                    media_info["media_type"],
+                )
+                return False
+
+        return True
+
     async def on_profile_query(self, args: JsonDict) -> JsonDict:
         """Handles federation profile query requests."""
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index f08a516a75..a719d5eef3 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -132,6 +132,7 @@ class RegistrationHandler:
         localpart: str,
         guest_access_token: Optional[str] = None,
         assigned_user_id: Optional[str] = None,
+        inhibit_user_in_use_error: bool = False,
     ) -> None:
         if types.contains_invalid_mxid_characters(localpart):
             raise SynapseError(
@@ -171,21 +172,22 @@ class RegistrationHandler:
 
         users = await self.store.get_users_by_id_case_insensitive(user_id)
         if users:
-            if not guest_access_token:
+            if not inhibit_user_in_use_error and not guest_access_token:
                 raise SynapseError(
                     400, "User ID already taken.", errcode=Codes.USER_IN_USE
                 )
-            user_data = await self.auth.get_user_by_access_token(guest_access_token)
-            if (
-                not user_data.is_guest
-                or UserID.from_string(user_data.user_id).localpart != localpart
-            ):
-                raise AuthError(
-                    403,
-                    "Cannot register taken user ID without valid guest "
-                    "credentials for that user.",
-                    errcode=Codes.FORBIDDEN,
-                )
+            if guest_access_token:
+                user_data = await self.auth.get_user_by_access_token(guest_access_token)
+                if (
+                    not user_data.is_guest
+                    or UserID.from_string(user_data.user_id).localpart != localpart
+                ):
+                    raise AuthError(
+                        403,
+                        "Cannot register taken user ID without valid guest "
+                        "credentials for that user.",
+                        errcode=Codes.FORBIDDEN,
+                    )
 
         if guest_access_token is None:
             try:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f963078e59..1420d67729 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -30,6 +30,7 @@ from typing import (
     Tuple,
 )
 
+import attr
 from typing_extensions import TypedDict
 
 from synapse.api.constants import (
@@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers.federation import get_domains_from_state
 from synapse.rest.admin._base import assert_user_is_admin
+from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
 from synapse.types import (
@@ -90,6 +92,17 @@ id_server_scheme = "https://"
 FIVE_MINUTES_IN_MS = 5 * 60 * 1000
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventContext:
+    events_before: List[EventBase]
+    event: EventBase
+    events_after: List[EventBase]
+    state: List[EventBase]
+    aggregations: Dict[str, BundledAggregations]
+    start: str
+    end: str
+
+
 class RoomCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -1119,7 +1132,7 @@ class RoomContextHandler:
         limit: int,
         event_filter: Optional[Filter],
         use_admin_priviledge: bool = False,
-    ) -> Optional[JsonDict]:
+    ) -> Optional[EventContext]:
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -1167,38 +1180,28 @@ class RoomContextHandler:
         results = await self.store.get_events_around(
             room_id, event_id, before_limit, after_limit, event_filter
         )
+        events_before = results.events_before
+        events_after = results.events_after
 
         if event_filter:
-            results["events_before"] = await event_filter.filter(
-                results["events_before"]
-            )
-            results["events_after"] = await event_filter.filter(results["events_after"])
+            events_before = await event_filter.filter(events_before)
+            events_after = await event_filter.filter(events_after)
 
-        results["events_before"] = await filter_evts(results["events_before"])
-        results["events_after"] = await filter_evts(results["events_after"])
+        events_before = await filter_evts(events_before)
+        events_after = await filter_evts(events_after)
         # filter_evts can return a pruned event in case the user is allowed to see that
         # there's something there but not see the content, so use the event that's in
         # `filtered` rather than the event we retrieved from the datastore.
-        results["event"] = filtered[0]
+        event = filtered[0]
 
         # Fetch the aggregations.
         aggregations = await self.store.get_bundled_aggregations(
-            [results["event"]], user.to_string()
+            itertools.chain(events_before, (event,), events_after),
+            user.to_string(),
         )
-        aggregations.update(
-            await self.store.get_bundled_aggregations(
-                results["events_before"], user.to_string()
-            )
-        )
-        aggregations.update(
-            await self.store.get_bundled_aggregations(
-                results["events_after"], user.to_string()
-            )
-        )
-        results["aggregations"] = aggregations
 
-        if results["events_after"]:
-            last_event_id = results["events_after"][-1].event_id
+        if events_after:
+            last_event_id = events_after[-1].event_id
         else:
             last_event_id = event_id
 
@@ -1206,9 +1209,9 @@ class RoomContextHandler:
             state_filter = StateFilter.from_lazy_load_member_list(
                 ev.sender
                 for ev in itertools.chain(
-                    results["events_before"],
-                    (results["event"],),
-                    results["events_after"],
+                    events_before,
+                    (event,),
+                    events_after,
                 )
             )
         else:
@@ -1226,21 +1229,23 @@ class RoomContextHandler:
         if event_filter:
             state_events = await event_filter.filter(state_events)
 
-        results["state"] = await filter_evts(state_events)
-
         # We use a dummy token here as we only care about the room portion of
         # the token, which we replace.
         token = StreamToken.START
 
-        results["start"] = await token.copy_and_replace(
-            "room_key", results["start"]
-        ).to_string(self.store)
-
-        results["end"] = await token.copy_and_replace(
-            "room_key", results["end"]
-        ).to_string(self.store)
-
-        return results
+        return EventContext(
+            events_before=events_before,
+            event=event,
+            events_after=events_after,
+            state=await filter_evts(state_events),
+            aggregations=aggregations,
+            start=await token.copy_and_replace("room_key", results.start).to_string(
+                self.store
+            ),
+            end=await token.copy_and_replace("room_key", results.end).to_string(
+                self.store
+            ),
+        )
 
 
 class TimestampLookupHandler:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6aa910dd10..3dd5e1b6e4 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -590,6 +590,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 errcode=Codes.BAD_JSON,
             )
 
+        if "avatar_url" in content:
+            if not await self.profile_handler.check_avatar_size_and_mime_type(
+                content["avatar_url"],
+            ):
+                raise SynapseError(403, "This avatar is not allowed", Codes.FORBIDDEN)
+
         # The event content should *not* include the authorising user as
         # it won't be properly signed. Strip it out since it might come
         # back from a client updating a display name / avatar.
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 0b153a6822..02bb5ae72f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -361,36 +361,37 @@ class SearchHandler:
 
                 logger.info(
                     "Context for search returned %d and %d events",
-                    len(res["events_before"]),
-                    len(res["events_after"]),
+                    len(res.events_before),
+                    len(res.events_after),
                 )
 
-                res["events_before"] = await filter_events_for_client(
-                    self.storage, user.to_string(), res["events_before"]
+                events_before = await filter_events_for_client(
+                    self.storage, user.to_string(), res.events_before
                 )
 
-                res["events_after"] = await filter_events_for_client(
-                    self.storage, user.to_string(), res["events_after"]
+                events_after = await filter_events_for_client(
+                    self.storage, user.to_string(), res.events_after
                 )
 
-                res["start"] = await now_token.copy_and_replace(
-                    "room_key", res["start"]
-                ).to_string(self.store)
-
-                res["end"] = await now_token.copy_and_replace(
-                    "room_key", res["end"]
-                ).to_string(self.store)
+                context = {
+                    "events_before": events_before,
+                    "events_after": events_after,
+                    "start": await now_token.copy_and_replace(
+                        "room_key", res.start
+                    ).to_string(self.store),
+                    "end": await now_token.copy_and_replace(
+                        "room_key", res.end
+                    ).to_string(self.store),
+                }
 
                 if include_profile:
                     senders = {
                         ev.sender
-                        for ev in itertools.chain(
-                            res["events_before"], [event], res["events_after"]
-                        )
+                        for ev in itertools.chain(events_before, [event], events_after)
                     }
 
-                    if res["events_after"]:
-                        last_event_id = res["events_after"][-1].event_id
+                    if events_after:
+                        last_event_id = events_after[-1].event_id
                     else:
                         last_event_id = event.event_id
 
@@ -402,7 +403,7 @@ class SearchHandler:
                         last_event_id, state_filter
                     )
 
-                    res["profile_info"] = {
+                    context["profile_info"] = {
                         s.state_key: {
                             "displayname": s.content.get("displayname", None),
                             "avatar_url": s.content.get("avatar_url", None),
@@ -411,7 +412,7 @@ class SearchHandler:
                         if s.type == EventTypes.Member and s.state_key in senders
                     }
 
-                contexts[event.event_id] = res
+                contexts[event.event_id] = context
         else:
             contexts = {}
 
@@ -421,10 +422,10 @@ class SearchHandler:
 
         for context in contexts.values():
             context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now
+                context["events_before"], time_now  # type: ignore[arg-type]
             )
             context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now
+                context["events_after"], time_now  # type: ignore[arg-type]
             )
 
         state_results = {}
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ffc6b748e8..c72ed7c290 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -100,7 +101,7 @@ class TimelineBatch:
     limited: bool
     # A mapping of event ID to the bundled aggregations for the above events.
     # This is only calculated if limited is true.
-    bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
+    bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -1619,7 +1620,7 @@ class SyncHandler:
         # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
         ignored_account_data = (
             await self.store.get_global_account_data_by_type_for_user(
-                AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+                user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
             )
         )