summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/auth.py28
-rw-r--r--synapse/handlers/devicemessage.py33
-rw-r--r--synapse/handlers/directory.py63
-rw-r--r--synapse/handlers/events.py2
-rw-r--r--synapse/handlers/federation.py127
-rw-r--r--synapse/handlers/identity.py9
-rw-r--r--synapse/handlers/message.py108
-rw-r--r--synapse/handlers/presence.py311
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/handlers/room.py29
-rw-r--r--synapse/handlers/room_member.py27
-rw-r--r--synapse/handlers/space_summary.py80
-rw-r--r--synapse/handlers/sync.py6
-rw-r--r--synapse/handlers/ui_auth/checkers.py35
14 files changed, 553 insertions, 309 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py

index 36f2450e2e..8a6666a4ad 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -17,6 +17,7 @@ import logging import time import unicodedata import urllib.parse +from binascii import crc32 from typing import ( TYPE_CHECKING, Any, @@ -34,6 +35,7 @@ from typing import ( import attr import bcrypt import pymacaroons +import unpaddedbase64 from twisted.web.server import Request @@ -66,6 +68,7 @@ from synapse.util import stringutils as stringutils from synapse.util.async_helpers import maybe_awaitable from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: @@ -808,10 +811,12 @@ class AuthHandler(BaseHandler): logger.info( "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry ) + target_user_id_obj = UserID.from_string(puppets_user_id) else: logger.info( "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry ) + target_user_id_obj = UserID.from_string(user_id) if ( not is_appservice_ghost @@ -819,7 +824,7 @@ class AuthHandler(BaseHandler): ): await self.auth.check_auth_blocking(user_id) - access_token = self.macaroon_gen.generate_access_token(user_id) + access_token = self.generate_access_token(target_user_id_obj) await self.store.add_access_token_to_user( user_id=user_id, token=access_token, @@ -1192,6 +1197,19 @@ class AuthHandler(BaseHandler): return None return user_id + def generate_access_token(self, for_user: UserID) -> str: + """Generates an opaque string, for use as an access token""" + + # we use the following format for access tokens: + # syt_<base64 local part>_<random string>_<base62 crc check> + + b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8")) + random_string = stringutils.random_string(20) + base = f"syt_{b64local}_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return f"{base}_{crc}" + async def validate_short_term_login_token( self, login_token: str ) -> LoginTokenAttributes: @@ -1585,10 +1603,7 @@ class MacaroonGenerator: hs = attr.ib() - def generate_access_token( - self, user_id: str, extra_caveats: Optional[List[str]] = None - ) -> str: - extra_caveats = extra_caveats or [] + def generate_guest_access_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") # Include a nonce, to make sure that each login gets a different @@ -1596,8 +1611,7 @@ class MacaroonGenerator: macaroon.add_first_party_caveat( "nonce = %s" % (stringutils.random_string_with_symbols(16),) ) - for caveat in extra_caveats: - macaroon.add_first_party_caveat(caveat) + macaroon.add_first_party_caveat("guest = true") return macaroon.serialize() def generate_short_term_login_token( diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index c5d631de07..580b941595 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py
@@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict -from synapse.api.constants import EduTypes +from synapse.api.constants import ToDeviceEventTypes from synapse.api.errors import SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background @@ -79,6 +79,8 @@ class DeviceMessageHandler: ReplicationUserDevicesResyncRestServlet.make_client(hs) ) + # a rate limiter for room key requests. The keys are + # (sending_user_id, sending_device_id). self._ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), @@ -100,12 +102,25 @@ class DeviceMessageHandler: for user_id, by_device in content["messages"].items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", user_id) + logger.warning("To-device message to non-local user %s", user_id) raise SynapseError(400, "Not a user here") if not by_device: continue + # Ratelimit key requests by the sending user. + if message_type == ToDeviceEventTypes.RoomKeyRequest: + allowed, _ = await self._ratelimiter.can_do_action( + None, (sender_user_id, None) + ) + if not allowed: + logger.info( + "Dropping room_key_request from %s to %s due to rate limit", + sender_user_id, + user_id, + ) + continue + messages_by_device = { device_id: { "content": message_content, @@ -192,13 +207,19 @@ class DeviceMessageHandler: for user_id, by_device in messages.items(): # Ratelimit local cross-user key requests by the sending device. if ( - message_type == EduTypes.RoomKeyRequest + message_type == ToDeviceEventTypes.RoomKeyRequest and user_id != sender_user_id - and await self._ratelimiter.can_do_action( + ): + allowed, _ = await self._ratelimiter.can_do_action( requester, (sender_user_id, requester.device_id) ) - ): - continue + if not allowed: + logger.info( + "Dropping room_key_request from %s to %s due to rate limit", + sender_user_id, + user_id, + ) + continue # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 90932316f3..4064a2b859 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py
@@ -14,7 +14,7 @@ import logging import string -from typing import Iterable, List, Optional +from typing import TYPE_CHECKING, Iterable, List, Optional from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( @@ -27,15 +27,19 @@ from synapse.api.errors import ( SynapseError, ) from synapse.appservice import ApplicationService -from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id +from synapse.storage.databases.main.directory import RoomAliasMapping +from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class DirectoryHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.state = hs.get_state_handler() @@ -60,7 +64,7 @@ class DirectoryHandler(BaseHandler): room_id: str, servers: Optional[Iterable[str]] = None, creator: Optional[str] = None, - ): + ) -> None: # general association creation for both human users and app services for wchar in string.whitespace: @@ -74,7 +78,7 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) servers = {get_domain_from_id(u) for u in users} if not servers: @@ -104,8 +108,9 @@ class DirectoryHandler(BaseHandler): """ user_id = requester.user.to_string() + room_alias_str = room_alias.to_string() - if len(room_alias.to_string()) > MAX_ALIAS_LENGTH: + if len(room_alias_str) > MAX_ALIAS_LENGTH: raise SynapseError( 400, "Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH, @@ -114,7 +119,7 @@ class DirectoryHandler(BaseHandler): service = requester.app_service if service: - if not service.is_interested_in_alias(room_alias.to_string()): + if not service.is_interested_in_alias(room_alias_str): raise SynapseError( 400, "This application service has not reserved this kind of alias.", @@ -138,7 +143,7 @@ class DirectoryHandler(BaseHandler): raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( - user_id, room_id, room_alias.to_string() + user_id, room_id, room_alias_str ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages @@ -211,7 +216,7 @@ class DirectoryHandler(BaseHandler): async def delete_appservice_association( self, service: ApplicationService, room_alias: RoomAlias - ): + ) -> None: if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( 400, @@ -220,7 +225,7 @@ class DirectoryHandler(BaseHandler): ) await self._delete_association(room_alias) - async def _delete_association(self, room_alias: RoomAlias): + async def _delete_association(self, room_alias: RoomAlias) -> str: if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") @@ -228,17 +233,19 @@ class DirectoryHandler(BaseHandler): return room_id - async def get_association(self, room_alias: RoomAlias): + async def get_association(self, room_alias: RoomAlias) -> JsonDict: room_id = None if self.hs.is_mine(room_alias): - result = await self.get_association_from_room_alias(room_alias) + result = await self.get_association_from_room_alias( + room_alias + ) # type: Optional[RoomAliasMapping] if result: room_id = result.room_id servers = result.servers else: try: - result = await self.federation.make_query( + fed_result = await self.federation.make_query( destination=room_alias.domain, query_type="directory", args={"room_alias": room_alias.to_string()}, @@ -248,13 +255,13 @@ class DirectoryHandler(BaseHandler): except CodeMessageException as e: logging.warning("Error retrieving alias") if e.code == 404: - result = None + fed_result = None else: raise - if result and "room_id" in result and "servers" in result: - room_id = result["room_id"] - servers = result["servers"] + if fed_result and "room_id" in fed_result and "servers" in fed_result: + room_id = fed_result["room_id"] + servers = fed_result["servers"] if not room_id: raise SynapseError( @@ -263,7 +270,7 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} servers = set(extra_servers) | set(servers) @@ -275,7 +282,7 @@ class DirectoryHandler(BaseHandler): return {"room_id": room_id, "servers": servers} - async def on_directory_query(self, args): + async def on_directory_query(self, args: JsonDict) -> JsonDict: room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room Alias is not hosted on this homeserver") @@ -293,7 +300,7 @@ class DirectoryHandler(BaseHandler): async def _update_canonical_alias( self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias - ): + ) -> None: """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. @@ -344,7 +351,9 @@ class DirectoryHandler(BaseHandler): ratelimit=False, ) - async def get_association_from_room_alias(self, room_alias: RoomAlias): + async def get_association_from_room_alias( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: result = await self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists @@ -372,7 +381,7 @@ class DirectoryHandler(BaseHandler): # either no interested services, or no service with an exclusive lock return True - async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool: """Determine whether a user can delete an alias. One of the following must be true: @@ -394,14 +403,13 @@ class DirectoryHandler(BaseHandler): if not room_id: return False - res = await self.auth.check_can_change_room_list( + return await self.auth.check_can_change_room_list( room_id, UserID.from_string(user_id) ) - return res async def edit_published_room_list( self, requester: Requester, room_id: str, visibility: str - ): + ) -> None: """Edit the entry of the room in the published room list. requester @@ -469,7 +477,7 @@ class DirectoryHandler(BaseHandler): async def edit_published_appservice_room_list( self, appservice_id: str, network_id: str, room_id: str, visibility: str - ): + ) -> None: """Add or remove a room from the appservice/network specific public room list. @@ -499,5 +507,4 @@ class DirectoryHandler(BaseHandler): room_id, requester.user.to_string() ) - aliases = await self.store.get_aliases_for_room(room_id) - return aliases + return await self.store.get_aliases_for_room(room_id) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d82144d7fa..f134f1e234 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py
@@ -103,7 +103,7 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = await self.state.get_current_users_in_room( + users = await self.store.get_users_in_room( event.room_id ) # type: Iterable[str] else: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b59e17ad4e..6a5c33f212 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -554,8 +554,12 @@ class FederationHandler(BaseHandler): destination: str, room_id: str, event_id: str, - ) -> Tuple[List[EventBase], List[EventBase]]: - """Requests all of the room state at a given event from a remote homeserver. + ) -> List[EventBase]: + """Requests all of the room state at a given event from a remote + homeserver. + + Will also fetch any missing events reported in the `auth_chain_ids` + section of `/state_ids`. Args: destination: The remote homeserver to query for the state. @@ -563,8 +567,7 @@ class FederationHandler(BaseHandler): event_id: The id of the event we want the state at. Returns: - A list of events in the state, not including the event itself, and - a list of events in the auth chain for the given event. + A list of events in the state, not including the event itself. """ ( state_event_ids, @@ -573,68 +576,53 @@ class FederationHandler(BaseHandler): destination, room_id, event_id=event_id ) - desired_events = set(state_event_ids + auth_event_ids) - - event_map = await self._get_events_from_store_or_dest( - destination, room_id, desired_events - ) + # Fetch the state events from the DB, and check we have the auth events. + event_map = await self.store.get_events(state_event_ids, allow_rejected=True) + auth_events_in_store = await self.store.have_seen_events(auth_event_ids) - failed_to_fetch = desired_events - event_map.keys() - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state/auth events for %s %s", - event_id, - failed_to_fetch, + # Check for missing events. We handle state and auth event seperately, + # as we want to pull the state from the DB, but we don't for the auth + # events. (Note: we likely won't use the majority of the auth chain, and + # it can be *huge* for large rooms, so it's worth ensuring that we don't + # unnecessarily pull it from the DB). + missing_state_events = set(state_event_ids) - set(event_map) + missing_auth_events = set(auth_event_ids) - set(auth_events_in_store) + if missing_state_events or missing_auth_events: + await self._get_events_and_persist( + destination=destination, + room_id=room_id, + events=missing_state_events | missing_auth_events, ) - remote_state = [ - event_map[e_id] for e_id in state_event_ids if e_id in event_map - ] - - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] - auth_chain.sort(key=lambda e: e.depth) - - return remote_state, auth_chain - - async def _get_events_from_store_or_dest( - self, destination: str, room_id: str, event_ids: Iterable[str] - ) -> Dict[str, EventBase]: - """Fetch events from a remote destination, checking if we already have them. - - Persists any events we don't already have as outliers. - - If we fail to fetch any of the events, a warning will be logged, and the event - will be omitted from the result. Likewise, any events which turn out not to - be in the given room. - - This function *does not* automatically get missing auth events of the - newly fetched events. Callers must include the full auth chain of - of the missing events in the `event_ids` argument, to ensure that any - missing auth events are correctly fetched. + if missing_state_events: + new_events = await self.store.get_events( + missing_state_events, allow_rejected=True + ) + event_map.update(new_events) - Returns: - map from event_id to event - """ - fetched_events = await self.store.get_events(event_ids, allow_rejected=True) + missing_state_events.difference_update(new_events) - missing_events = set(event_ids) - fetched_events.keys() + if missing_state_events: + logger.warning( + "Failed to fetch missing state events for %s %s", + event_id, + missing_state_events, + ) - if missing_events: - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - room_id, - ) + if missing_auth_events: + auth_events_in_store = await self.store.have_seen_events( + missing_auth_events + ) + missing_auth_events.difference_update(auth_events_in_store) - await self._get_events_and_persist( - destination=destination, room_id=room_id, events=missing_events - ) + if missing_auth_events: + logger.warning( + "Failed to fetch missing auth events for %s %s", + event_id, + missing_auth_events, + ) - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - (await self.store.get_events(missing_events, allow_rejected=True)) - ) + remote_state = list(event_map.values()) # check for events which were in the wrong room. # @@ -642,8 +630,8 @@ class FederationHandler(BaseHandler): # auth_events at an event in room A are actually events in room B bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() + (event.event_id, event.room_id) + for event in remote_state if event.room_id != room_id ] @@ -660,9 +648,10 @@ class FederationHandler(BaseHandler): room_id, ) - del fetched_events[bad_event_id] + if bad_events: + remote_state = [e for e in remote_state if e.room_id == room_id] - return fetched_events + return remote_state async def _get_state_after_missing_prev_event( self, @@ -965,27 +954,23 @@ class FederationHandler(BaseHandler): # For each edge get the current state. - auth_events = {} state_events = {} events_to_state = {} for e_id in edges: - state, auth = await self._get_state_for_room( + state = await self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id, ) - auth_events.update({a.event_id: a for a in auth}) - auth_events.update({s.event_id: s for s in state}) state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state required_auth = { a_id - for event in events - + list(state_events.values()) - + list(auth_events.values()) + for event in events + list(state_events.values()) for a_id in event.auth_event_ids() } + auth_events = await self.store.get_events(required_auth, allow_rejected=True) auth_events.update( {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} ) @@ -2640,7 +2625,9 @@ class FederationHandler(BaseHandler): # If we are going to send this event over federation we precaclculate # the joined hosts. if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event(event) + await self.event_creation_handler.cache_joined_hosts_for_event( + event, context + ) return context diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 3d686df04c..b0f192ce1c 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -17,7 +17,7 @@ """Utilities for interacting with Identity Servers""" import logging import urllib.parse -from typing import Awaitable, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple from synapse.api.errors import ( AuthError, @@ -43,11 +43,14 @@ from synapse.util.stringutils import ( from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class IdentityHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) # An HTTP client for contacting trusted URLs. @@ -83,7 +86,7 @@ class IdentityHandler(BaseHandler): request: SynapseRequest, medium: str, address: str, - ): + ) -> None: """Used to ratelimit requests to `/requestToken` by IP and address. Args: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 02db753ed6..d847b863d9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -16,10 +16,11 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple from canonicaljson import encode_canonical_json +from twisted.internet import defer from twisted.internet.interfaces import IDelayedCall from synapse import event_auth @@ -45,14 +46,15 @@ from synapse.events import EventBase from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator -from synapse.logging.context import run_in_background +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester -from synapse.util import json_decoder, json_encoder -from synapse.util.async_helpers import Linearizer +from synapse.util import json_decoder, json_encoder, log_failure +from synapse.util.async_helpers import Linearizer, unwrapFirstError +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -68,7 +70,7 @@ logger = logging.getLogger(__name__) class MessageHandler: """Contains some read only APIs to get state about a room""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -93,7 +95,7 @@ class MessageHandler: room_id: str, event_type: str, state_key: str, - ) -> dict: + ) -> Optional[EventBase]: """Get data from a room. Args: @@ -117,6 +119,10 @@ class MessageHandler: data = await self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) + # If the membership is not JOIN, then the event ID should exist. + assert ( + membership_event_id is not None + ), "check_user_in_room_or_world_readable returned invalid data" room_state = await self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) @@ -188,10 +194,12 @@ class MessageHandler: event = last_events[0] if visible_events: - room_state = await self.state_store.get_state_for_events( + room_state_events = await self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) - room_state = room_state[event.event_id] + room_state = room_state_events[ + event.event_id + ] # type: Mapping[Any, EventBase] else: raise AuthError( 403, @@ -212,10 +220,14 @@ class MessageHandler: ) room_state = await self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = await self.state_store.get_state_for_events( + # If the membership is not JOIN, then the event ID should exist. + assert ( + membership_event_id is not None + ), "check_user_in_room_or_world_readable returned invalid data" + room_state_events = await self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) - room_state = room_state[membership_event_id] + room_state = room_state_events[membership_event_id] now = self.clock.time_msec() events = await self._event_serializer.serialize_events( @@ -250,7 +262,7 @@ class MessageHandler: "Getting joined members after leaving is not implemented" ) - users_with_profile = await self.state.get_current_users_in_room(room_id) + users_with_profile = await self.store.get_users_in_room_with_profiles(room_id) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there @@ -449,6 +461,19 @@ class EventCreationHandler: self._external_cache = hs.get_external_cache() + # Stores the state groups we've recently added to the joined hosts + # external cache. Note that the timeout must be significantly less than + # the TTL on the external cache. + self._external_cache_joined_hosts_updates = ( + None + ) # type: Optional[ExpiringCache] + if self._external_cache.is_enabled(): + self._external_cache_joined_hosts_updates = ExpiringCache( + "_external_cache_joined_hosts_updates", + self.clock, + expiry_ms=30 * 60 * 1000, + ) + async def create_event( self, requester: Requester, @@ -957,9 +982,43 @@ class EventCreationHandler: logger.exception("Failed to encode content: %r", event.content) raise - await self.action_generator.handle_push_actions_for_event(event, context) + # We now persist the event (and update the cache in parallel, since we + # don't want to block on it). + result = await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self._persist_event, + requester=requester, + event=event, + context=context, + ratelimit=ratelimit, + extra_users=extra_users, + ), + run_in_background( + self.cache_joined_hosts_for_event, event, context + ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), + ], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + + return result[0] + + async def _persist_event( + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + extra_users: Optional[List[UserID]] = None, + ) -> EventBase: + """Actually persists the event. Should only be called by + `handle_new_client_event`, and see its docstring for documentation of + the arguments. + """ - await self.cache_joined_hosts_for_event(event) + await self.action_generator.handle_push_actions_for_event(event, context) try: # If we're a worker we need to hit out to the master. @@ -1000,7 +1059,9 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise - async def cache_joined_hosts_for_event(self, event: EventBase) -> None: + async def cache_joined_hosts_for_event( + self, event: EventBase, context: EventContext + ) -> None: """Precalculate the joined hosts at the event, when using Redis, so that external federation senders don't have to recalculate it themselves. """ @@ -1008,6 +1069,9 @@ class EventCreationHandler: if not self._external_cache.is_enabled(): return + # If external cache is enabled we should always have this. + assert self._external_cache_joined_hosts_updates is not None + # We actually store two mappings, event ID -> prev state group, # state group -> joined hosts, which is much more space efficient # than event ID -> joined hosts. @@ -1015,22 +1079,28 @@ class EventCreationHandler: # Note: We have to cache event ID -> prev state group, as we don't # store that in the DB. # - # Note: We always set the state group -> joined hosts cache, even if - # we already set it, so that the expiry time is reset. + # Note: We set the state group -> joined hosts cache if it hasn't been + # set for a while, so that the expiry time is reset. state_entry = await self.state.resolve_state_groups_for_events( event.room_id, event_ids=event.prev_event_ids() ) if state_entry.state_group: - joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) - await self._external_cache.set( "event_to_prev_state_group", event.event_id, state_entry.state_group, expiry_ms=60 * 60 * 1000, ) + + if state_entry.state_group in self._external_cache_joined_hosts_updates: + return + + joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + + # Note that the expiry times must be larger than the expiry time in + # _external_cache_joined_hosts_updates. await self._external_cache.set( "get_joined_hosts", str(state_entry.state_group), @@ -1038,6 +1108,8 @@ class EventCreationHandler: expiry_ms=60 * 60 * 1000, ) + self._external_cache_joined_hosts_updates[state_entry.state_group] = None + async def _validate_canonical_alias( self, directory_handler, room_alias_str: str, expected_room_id: str ) -> None: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 12df35f26e..6fd1f34289 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -28,6 +28,7 @@ from bisect import bisect from contextlib import contextmanager from typing import ( TYPE_CHECKING, + Callable, Collection, Dict, FrozenSet, @@ -232,23 +233,23 @@ class BasePresenceHandler(abc.ABC): """ async def update_external_syncs_row( - self, process_id, user_id, is_syncing, sync_time_msec - ): + self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + ) -> None: """Update the syncing users for an external process as a delta. This is a no-op when presence is handled by a different worker. Args: - process_id (str): An identifier for the process the users are + process_id: An identifier for the process the users are syncing against. This allows synapse to process updates as user start and stop syncing against a given process. - user_id (str): The user who has started or stopped syncing - is_syncing (bool): Whether or not the user is now syncing - sync_time_msec(int): Time in ms when the user was last syncing + user_id: The user who has started or stopped syncing + is_syncing: Whether or not the user is now syncing + sync_time_msec: Time in ms when the user was last syncing """ pass - async def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id: str) -> None: """Marks all users that had been marked as syncing by a given process as offline. @@ -304,7 +305,7 @@ class _NullContextManager(ContextManager[None]): class WorkerPresenceHandler(BasePresenceHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs @@ -327,7 +328,7 @@ class WorkerPresenceHandler(BasePresenceHandler): # user_id -> last_sync_ms. Lists the users that have stopped syncing but # we haven't notified the presence writer of that yet - self.users_going_offline = {} + self.users_going_offline = {} # type: Dict[str, int] self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -346,24 +347,21 @@ class WorkerPresenceHandler(BasePresenceHandler): self._on_shutdown, ) - def _on_shutdown(self): + def _on_shutdown(self) -> None: if self._presence_enabled: self.hs.get_tcp_replication().send_command( ClearUserSyncsCommand(self.instance_id) ) - def send_user_sync(self, user_id, is_syncing, last_sync_ms): + def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None: if self._presence_enabled: self.hs.get_tcp_replication().send_user_sync( self.instance_id, user_id, is_syncing, last_sync_ms ) - def mark_as_coming_online(self, user_id): + def mark_as_coming_online(self, user_id: str) -> None: """A user has started syncing. Send a UserSync to the presence writer, unless they had recently stopped syncing. - - Args: - user_id (str) """ going_offline = self.users_going_offline.pop(user_id, None) if not going_offline: @@ -371,18 +369,15 @@ class WorkerPresenceHandler(BasePresenceHandler): # were offline self.send_user_sync(user_id, True, self.clock.time_msec()) - def mark_as_going_offline(self, user_id): + def mark_as_going_offline(self, user_id: str) -> None: """A user has stopped syncing. We wait before notifying the presence writer as its likely they'll come back soon. This allows us to avoid sending a stopped syncing immediately followed by a started syncing notification to the presence writer - - Args: - user_id (str) """ self.users_going_offline[user_id] = self.clock.time_msec() - def send_stop_syncing(self): + def send_stop_syncing(self) -> None: """Check if there are any users who have stopped syncing a while ago and haven't come back yet. If there are poke the presence writer about them. """ @@ -430,7 +425,9 @@ class WorkerPresenceHandler(BasePresenceHandler): return _user_syncing() - async def notify_from_replication(self, states, stream_id): + async def notify_from_replication( + self, states: List[UserPresenceState], stream_id: int + ) -> None: parties = await get_interested_parties(self.store, self.presence_router, states) room_ids_to_states, users_to_states = parties @@ -478,7 +475,12 @@ class WorkerPresenceHandler(BasePresenceHandler): if count > 0 ] - async def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state( + self, + target_user: UserID, + state: JsonDict, + ignore_status_msg: bool = False, + ) -> None: """Set the presence state of the user.""" presence = state["presence"] @@ -508,7 +510,7 @@ class WorkerPresenceHandler(BasePresenceHandler): ignore_status_msg=ignore_status_msg, ) - async def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user: UserID) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -592,8 +594,8 @@ class PresenceHandler(BasePresenceHandler): # we assume that all the sync requests on that process have stopped. # Stored as a dict from process_id to set of user_id, and a dict of # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs = {} # type: Dict[int, Set[str]] - self.external_process_last_updated_ms = {} # type: Dict[int, int] + self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]] + self.external_process_last_updated_ms = {} # type: Dict[str, int] self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -633,7 +635,7 @@ class PresenceHandler(BasePresenceHandler): self._event_pos = self.store.get_current_events_token() self._event_processing = False - async def _on_shutdown(self): + async def _on_shutdown(self) -> None: """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal timers. This allows changes to persist across startup without having to @@ -662,7 +664,7 @@ class PresenceHandler(BasePresenceHandler): ) logger.info("Finished _on_shutdown") - async def _persist_unpersisted_changes(self): + async def _persist_unpersisted_changes(self) -> None: """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. """ @@ -762,7 +764,7 @@ class PresenceHandler(BasePresenceHandler): states, destinations ) - async def _handle_timeouts(self): + async def _handle_timeouts(self) -> None: """Checks the presence of users that have timed out and updates as appropriate. """ @@ -814,7 +816,7 @@ class PresenceHandler(BasePresenceHandler): return await self._update_states(changes) - async def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user: UserID) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -911,17 +913,17 @@ class PresenceHandler(BasePresenceHandler): return [] async def update_external_syncs_row( - self, process_id, user_id, is_syncing, sync_time_msec - ): + self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + ) -> None: """Update the syncing users for an external process as a delta. Args: - process_id (str): An identifier for the process the users are + process_id: An identifier for the process the users are syncing against. This allows synapse to process updates as user start and stop syncing against a given process. - user_id (str): The user who has started or stopped syncing - is_syncing (bool): Whether or not the user is now syncing - sync_time_msec(int): Time in ms when the user was last syncing + user_id: The user who has started or stopped syncing + is_syncing: Whether or not the user is now syncing + sync_time_msec: Time in ms when the user was last syncing """ with (await self.external_sync_linearizer.queue(process_id)): prev_state = await self.current_state_for_user(user_id) @@ -958,7 +960,7 @@ class PresenceHandler(BasePresenceHandler): self.external_process_last_updated_ms[process_id] = self.clock.time_msec() - async def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id: str) -> None: """Marks all users that had been marked as syncing by a given process as offline. @@ -979,12 +981,12 @@ class PresenceHandler(BasePresenceHandler): ) self.external_process_last_updated_ms.pop(process_id, None) - async def current_state_for_user(self, user_id): + async def current_state_for_user(self, user_id: str) -> UserPresenceState: """Get the current presence state for a user.""" res = await self.current_state_for_users([user_id]) return res[user_id] - async def _persist_and_notify(self, states): + async def _persist_and_notify(self, states: List[UserPresenceState]) -> None: """Persist states in the database, poke the notifier and send to interested remote servers """ @@ -1005,7 +1007,7 @@ class PresenceHandler(BasePresenceHandler): # stream (which is updated by `store.update_presence`). await self.maybe_send_presence_to_interested_destinations(states) - async def incoming_presence(self, origin, content): + async def incoming_presence(self, origin: str, content: JsonDict) -> None: """Called when we receive a `m.presence` EDU from a remote server.""" if not self._presence_enabled: return @@ -1055,7 +1057,9 @@ class PresenceHandler(BasePresenceHandler): federation_presence_counter.inc(len(updates)) await self._update_states(updates) - async def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state( + self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False + ) -> None: """Set the presence state of the user.""" status_msg = state.get("status_msg", None) presence = state["presence"] @@ -1089,7 +1093,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states([prev_state.copy_and_replace(**new_fields)]) - async def is_visible(self, observed_user, observer_user): + async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: """Returns whether a user can see another user's presence.""" observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() @@ -1144,7 +1148,7 @@ class PresenceHandler(BasePresenceHandler): ) return rows - def notify_new_event(self): + def notify_new_event(self) -> None: """Called when new events have happened. Handles users and servers joining rooms and require being sent presence. """ @@ -1163,7 +1167,7 @@ class PresenceHandler(BasePresenceHandler): run_as_background_process("presence.notify_new_event", _process_presence) - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "presence_delta"): @@ -1179,7 +1183,16 @@ class PresenceHandler(BasePresenceHandler): max_pos, deltas = await self.store.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) - await self._handle_state_delta(deltas) + + # We may get multiple deltas for different rooms, but we want to + # handle them on a room by room basis, so we batch them up by + # room. + deltas_by_room: Dict[str, List[JsonDict]] = {} + for delta in deltas: + deltas_by_room.setdefault(delta["room_id"], []).append(delta) + + for room_id, deltas_for_room in deltas_by_room.items(): + await self._handle_state_delta(room_id, deltas_for_room) self._event_pos = max_pos @@ -1188,17 +1201,21 @@ class PresenceHandler(BasePresenceHandler): max_pos ) - async def _handle_state_delta(self, deltas): - """Process current state deltas to find new joins that need to be - handled. + async def _handle_state_delta(self, room_id: str, deltas: List[JsonDict]) -> None: + """Process current state deltas for the room to find new joins that need + to be handled. """ - # A map of destination to a set of user state that they should receive - presence_destinations = {} # type: Dict[str, Set[UserPresenceState]] + + # Sets of newly joined users. Note that if the local server is + # joining a remote room for the first time we'll see both the joining + # user and all remote users as newly joined. + newly_joined_users = set() for delta in deltas: + assert room_id == delta["room_id"] + typ = delta["type"] state_key = delta["state_key"] - room_id = delta["room_id"] event_id = delta["event_id"] prev_event_id = delta["prev_event_id"] @@ -1227,72 +1244,55 @@ class PresenceHandler(BasePresenceHandler): # Ignore changes to join events. continue - # Retrieve any user presence state updates that need to be sent as a result, - # and the destinations that need to receive it - destinations, user_presence_states = await self._on_user_joined_room( - room_id, state_key - ) - - # Insert the destinations and respective updates into our destinations dict - for destination in destinations: - presence_destinations.setdefault(destination, set()).update( - user_presence_states - ) - - # Send out user presence updates for each destination - for destination, user_state_set in presence_destinations.items(): - self._federation_queue.send_presence_to_destinations( - destinations=[destination], states=user_state_set - ) + newly_joined_users.add(state_key) - async def _on_user_joined_room( - self, room_id: str, user_id: str - ) -> Tuple[List[str], List[UserPresenceState]]: - """Called when we detect a user joining the room via the current state - delta stream. Returns the destinations that need to be updated and the - presence updates to send to them. - - Args: - room_id: The ID of the room that the user has joined. - user_id: The ID of the user that has joined the room. - - Returns: - A tuple of destinations and presence updates to send to them. - """ - if self.is_mine_id(user_id): - # If this is a local user then we need to send their presence - # out to hosts in the room (who don't already have it) - - # TODO: We should be able to filter the hosts down to those that - # haven't previously seen the user - - remote_hosts = await self.state.get_current_hosts_in_room(room_id) - - # Filter out ourselves. - filtered_remote_hosts = [ - host for host in remote_hosts if host != self.server_name - ] - - state = await self.current_state_for_user(user_id) - return filtered_remote_hosts, [state] - else: - # A remote user has joined the room, so we need to: - # 1. Check if this is a new server in the room - # 2. If so send any presence they don't already have for - # local users in the room. - - # TODO: We should be able to filter the users down to those that - # the server hasn't previously seen - - # TODO: Check that this is actually a new server joining the - # room. - - remote_host = get_domain_from_id(user_id) + if not newly_joined_users: + # If nobody has joined then there's nothing to do. + return - users = await self.state.get_current_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, users)) + # We want to send: + # 1. presence states of all local users in the room to newly joined + # remote servers + # 2. presence states of newly joined users to all remote servers in + # the room. + # + # TODO: Only send presence states to remote hosts that don't already + # have them (because they already share rooms). + + # Get all the users who were already in the room, by fetching the + # current users in the room and removing the newly joined users. + users = await self.store.get_users_in_room(room_id) + prev_users = set(users) - newly_joined_users + + # Construct sets for all the local users and remote hosts that were + # already in the room + prev_local_users = [] + prev_remote_hosts = set() + for user_id in prev_users: + if self.is_mine_id(user_id): + prev_local_users.append(user_id) + else: + prev_remote_hosts.add(get_domain_from_id(user_id)) + + # Similarly, construct sets for all the local users and remote hosts + # that were *not* already in the room. Care needs to be taken with the + # calculating the remote hosts, as a host may have already been in the + # room even if there is a newly joined user from that host. + newly_joined_local_users = [] + newly_joined_remote_hosts = set() + for user_id in newly_joined_users: + if self.is_mine_id(user_id): + newly_joined_local_users.append(user_id) + else: + host = get_domain_from_id(user_id) + if host not in prev_remote_hosts: + newly_joined_remote_hosts.add(host) - states_d = await self.current_state_for_users(user_ids) + # Send presence states of all local users in the room to newly joined + # remote servers. (We actually only send states for local users already + # in the room, as we'll send states for newly joined local users below.) + if prev_local_users and newly_joined_remote_hosts: + local_states = await self.current_state_for_users(prev_local_users) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -1302,16 +1302,30 @@ class PresenceHandler(BasePresenceHandler): now = self.clock.time_msec() states = [ state - for state in states_d.values() + for state in local_states.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None ] - return [remote_host], states + self._federation_queue.send_presence_to_destinations( + destinations=newly_joined_remote_hosts, + states=states, + ) + # Send presence states of newly joined users to all remote servers in + # the room + if newly_joined_local_users and ( + prev_remote_hosts or newly_joined_remote_hosts + ): + local_states = await self.current_state_for_users(newly_joined_local_users) + self._federation_queue.send_presence_to_destinations( + destinations=prev_remote_hosts | newly_joined_remote_hosts, + states=list(local_states.values()), + ) -def should_notify(old_state, new_state): + +def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool: """Decides if a presence state change should be sent to interested parties.""" if old_state == new_state: return False @@ -1347,7 +1361,9 @@ def should_notify(old_state, new_state): return False -def format_user_presence_state(state, now, include_user_id=True): +def format_user_presence_state( + state: UserPresenceState, now: int, include_user_id: bool = True +) -> JsonDict: """Convert UserPresenceState to a format that can be sent down to clients and to other servers. @@ -1385,11 +1401,11 @@ class PresenceEventSource: @log_function async def get_new_events( self, - user, - from_key, - room_ids=None, - include_offline=True, - explicit_room_id=None, + user: UserID, + from_key: Optional[int], + room_ids: Optional[List[str]] = None, + include_offline: bool = True, + explicit_room_id: Optional[str] = None, **kwargs, ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: @@ -1594,7 +1610,7 @@ class PresenceEventSource: if update.state != PresenceState.OFFLINE ] - def get_current_key(self): + def get_current_key(self) -> int: return self.store.get_current_presence_token() @cached(num_args=2, cache_context=True) @@ -1654,15 +1670,20 @@ class PresenceEventSource: return users_interested_in -def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): +def handle_timeouts( + user_states: List[UserPresenceState], + is_mine_fn: Callable[[str], bool], + syncing_user_ids: Set[str], + now: int, +) -> List[UserPresenceState]: """Checks the presence of users that have timed out and updates as appropriate. Args: - user_states(list): List of UserPresenceState's to check. - is_mine_fn (fn): Function that returns if a user_id is ours - syncing_user_ids (set): Set of user_ids with active syncs. - now (int): Current time in ms. + user_states: List of UserPresenceState's to check. + is_mine_fn: Function that returns if a user_id is ours + syncing_user_ids: Set of user_ids with active syncs. + now: Current time in ms. Returns: List of UserPresenceState updates @@ -1679,14 +1700,16 @@ def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): return list(changes.values()) -def handle_timeout(state, is_mine, syncing_user_ids, now): +def handle_timeout( + state: UserPresenceState, is_mine: bool, syncing_user_ids: Set[str], now: int +) -> Optional[UserPresenceState]: """Checks the presence of the user to see if any of the timers have elapsed Args: - state (UserPresenceState) - is_mine (bool): Whether the user is ours - syncing_user_ids (set): Set of user_ids with active syncs. - now (int): Current time in ms. + state + is_mine: Whether the user is ours + syncing_user_ids: Set of user_ids with active syncs. + now: Current time in ms. Returns: A UserPresenceState update or None if no update. @@ -1738,23 +1761,29 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): return state if changed else None -def handle_update(prev_state, new_state, is_mine, wheel_timer, now): +def handle_update( + prev_state: UserPresenceState, + new_state: UserPresenceState, + is_mine: bool, + wheel_timer: WheelTimer, + now: int, +) -> Tuple[UserPresenceState, bool, bool]: """Given a presence update: 1. Add any appropriate timers. 2. Check if we should notify anyone. Args: - prev_state (UserPresenceState) - new_state (UserPresenceState) - is_mine (bool): Whether the user is ours - wheel_timer (WheelTimer) - now (int): Time now in ms + prev_state + new_state + is_mine: Whether the user is ours + wheel_timer + now: Time now in ms Returns: 3-tuple: `(new_state, persist_and_notify, federation_ping)` where: - new_state: is the state to actually persist - - persist_and_notify (bool): whether to persist and notify people - - federation_ping (bool): whether we should send a ping over federation + - persist_and_notify: whether to persist and notify people + - federation_ping: whether we should send a ping over federation """ user_id = new_state.user_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 3d4869067f..7830efb9dc 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -802,9 +802,7 @@ class RegistrationHandler(BaseHandler): ) if is_guest: assert valid_until_ms is None - access_token = self.macaroon_gen.generate_access_token( - user_id, ["guest = true"] - ) + access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: access_token = await self._auth_handler.get_access_token_for_user_id( user_id, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 1a99a0c827..58739e5016 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -32,7 +32,14 @@ from synapse.api.constants import ( RoomCreationPreset, RoomEncryptionAlgorithms, ) -from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + LimitExceededError, + NotFoundError, + StoreError, + SynapseError, +) from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase @@ -126,10 +133,6 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() - self._invite_burst_count = ( - hs.config.ratelimiting.rc_invites_per_room.burst_count - ) - async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ) -> str: @@ -694,8 +697,18 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = [] invite_list = [] - if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count: - raise SynapseError(400, "Cannot invite so many users at once") + if invite_list or invite_3pid_list: + try: + # If there are invites in the request, see if the ratelimiting settings + # allow that number of invites to be sent from the current user. + await self.room_member_handler.ratelimit_multiple_invites( + requester, + room_id=None, + n_invites=len(invite_list) + len(invite_3pid_list), + update=False, + ) + except LimitExceededError: + raise SynapseError(400, "Cannot invite so many users at once") await self.event_creation_handler.assert_accepted_privacy_policy(requester) @@ -1348,7 +1361,7 @@ class RoomShutdownHandler: new_room_id = None logger.info("Shutting down room %r", room_id) - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] for user_id in users: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 31c11280e8..835d5862c1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -210,6 +210,31 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def forget(self, user: UserID, room_id: str) -> None: raise NotImplementedError() + async def ratelimit_multiple_invites( + self, + requester: Optional[Requester], + room_id: Optional[str], + n_invites: int, + update: bool = True, + ): + """Ratelimit more than one invite sent by the given requester in the given room. + + Args: + requester: The requester sending the invites. + room_id: The room the invites are being sent in. + n_invites: The amount of invites to ratelimit for. + update: Whether to update the ratelimiter's cache. + + Raises: + LimitExceededError: The requester can't send that many invites in the room. + """ + await self._invites_per_room_limiter.ratelimit( + requester, + room_id, + update=update, + n_actions=n_invites, + ) + async def ratelimit_invite( self, requester: Optional[Requester], @@ -1170,7 +1195,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): class RoomMemberMasterHandler(RoomMemberHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.distributor = hs.get_distributor() diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index 01e3e050f9..e35d91832b 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py
@@ -14,6 +14,7 @@ import itertools import logging +import re from collections import deque from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple, cast @@ -226,6 +227,23 @@ class SpaceSummaryHandler: suggested_only: bool, max_children: Optional[int], ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + """ + Generate a room entry and a list of event entries for a given room. + + Args: + requester: The requesting user, or None if this is over federation. + room_id: The room ID to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + max_children: The maximum number of children to return for this node. + + Returns: + A tuple of: + An iterable of a single value of the room. + + An iterable of the sorted children events. This may be limited + to a maximum size or may include all children. + """ if not await self._is_room_accessible(room_id, requester): return (), () @@ -288,6 +306,7 @@ class SpaceSummaryHandler: ev.data for ev in res.events if ev.event_type == EventTypes.MSC1772_SPACE_CHILD + or ev.event_type == EventTypes.SpaceChild ) async def _is_room_accessible(self, room_id: str, requester: Optional[str]) -> bool: @@ -331,7 +350,9 @@ class SpaceSummaryHandler: ) # TODO: update once MSC1772 lands - room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + if not room_type: + room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE) entry = { "room_id": stats["room_id"], @@ -344,6 +365,7 @@ class SpaceSummaryHandler: stats["history_visibility"] == HistoryVisibility.WORLD_READABLE ), "guest_can_join": stats["guest_access"] == "can_join", + "creation_ts": create_event.origin_server_ts, "room_type": room_type, } @@ -353,6 +375,18 @@ class SpaceSummaryHandler: return room_entry async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: + """ + Get the child events for a given room. + + The returned results are sorted for stability. + + Args: + room_id: The room id to get the children of. + + Returns: + An iterable of sorted child events. + """ + # look for child rooms/spaces. current_state_ids = await self._store.get_current_state_ids(room_id) @@ -360,13 +394,15 @@ class SpaceSummaryHandler: [ event_id for key, event_id in current_state_ids.items() - # TODO: update once MSC1772 lands + # TODO: update once MSC1772 has been FCP for a period of time. if key[0] == EventTypes.MSC1772_SPACE_CHILD + or key[0] == EventTypes.SpaceChild ] ) - # filter out any events without a "via" (which implies it has been redacted) - return (e for e in events if _has_valid_via(e)) + # filter out any events without a "via" (which implies it has been redacted), + # and order to ensure we return stable results. + return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) @attr.s(frozen=True, slots=True) @@ -392,3 +428,39 @@ def _is_suggested_child_event(edge_event: EventBase) -> bool: return True logger.debug("Ignorning not-suggested child %s", edge_event.state_key) return False + + +# Order may only contain characters in the range of \x20 (space) to \x7F (~). +_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7F]") + + +def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]: + """ + Generate a value for comparing two child events for ordering. + + The rules for ordering are supposed to be: + + 1. The 'order' key, if it is valid. + 2. The 'origin_server_ts' of the 'm.room.create' event. + 3. The 'room_id'. + + But we skip step 2 since we may not have any state from the room. + + Args: + child: The event for generating a comparison key. + + Returns: + The comparison key as a tuple of: + False if the ordering is valid. + The ordering field. + The room ID. + """ + order = child.content.get("order") + # If order is not a string or doesn't meet the requirements, ignore it. + if not isinstance(order, str): + order = None + elif len(order) > 50 or _INVALID_ORDER_CHARS_RE.search(order): + order = None + + # Items without an order come last. + return (order is None, order, child.room_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a1b0aee355..3bc02fb406 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -1210,7 +1210,7 @@ class SyncHandler: # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: - joined_users = await self.state.get_current_users_in_room(room_id) + joined_users = await self.store.get_users_in_room(room_id) newly_joined_or_invited_or_knocked_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they @@ -1226,7 +1226,7 @@ class SyncHandler: # Now find users that we no longer track for room_id in newly_left_rooms: - left_users = await self.state.get_current_users_in_room(room_id) + left_users = await self.store.get_users_in_room(room_id) newly_left_users.update(left_users) # Remove any users that we still share a room with. @@ -1381,7 +1381,7 @@ class SyncHandler: extra_users_ids = set(newly_joined_or_invited_users) for room_id in newly_joined_rooms: - users = await self.state.get_current_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 0eeb7c03f2..5414ce77d8 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Any +from typing import TYPE_CHECKING, Any from twisted.web.client import PartialDownloadError @@ -22,13 +22,16 @@ from synapse.api.errors import Codes, LoginError, SynapseError from synapse.config.emailconfig import ThreepidBehaviour from synapse.util import json_decoder +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class UserInteractiveAuthChecker: """Abstract base class for an interactive auth checker""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): pass def is_enabled(self) -> bool: @@ -57,10 +60,10 @@ class UserInteractiveAuthChecker: class DummyAuthChecker(UserInteractiveAuthChecker): AUTH_TYPE = LoginType.DUMMY - def is_enabled(self): + def is_enabled(self) -> bool: return True - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: return True @@ -70,24 +73,24 @@ class TermsAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: return True class RecaptchaAuthChecker(UserInteractiveAuthChecker): AUTH_TYPE = LoginType.RECAPTCHA - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._enabled = bool(hs.config.recaptcha_private_key) self._http_client = hs.get_proxied_http_client() self._url = hs.config.recaptcha_siteverify_api self._secret = hs.config.recaptcha_private_key - def is_enabled(self): + def is_enabled(self) -> bool: return self._enabled - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: try: user_response = authdict["response"] except KeyError: @@ -132,11 +135,11 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): class _BaseThreepidAuthChecker: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() - async def _check_threepid(self, medium, authdict): + async def _check_threepid(self, medium: str, authdict: dict) -> dict: if "threepid_creds" not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) @@ -206,31 +209,31 @@ class _BaseThreepidAuthChecker: class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): AUTH_TYPE = LoginType.EMAIL_IDENTITY - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): UserInteractiveAuthChecker.__init__(self, hs) _BaseThreepidAuthChecker.__init__(self, hs) - def is_enabled(self): + def is_enabled(self) -> bool: return self.hs.config.threepid_behaviour_email in ( ThreepidBehaviour.REMOTE, ThreepidBehaviour.LOCAL, ) - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("email", authdict) class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): AUTH_TYPE = LoginType.MSISDN - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): UserInteractiveAuthChecker.__init__(self, hs) _BaseThreepidAuthChecker.__init__(self, hs) - def is_enabled(self): + def is_enabled(self) -> bool: return bool(self.hs.config.account_threepid_delegate_msisdn) - async def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("msisdn", authdict)