diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 66ce7e8b83..5b927f10b3 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -17,7 +17,7 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import StoreError, SynapseError
from synapse.logging.context import make_deferred_yieldable
@@ -39,28 +39,44 @@ class AccountValidityHandler:
self.sendmail = self.hs.get_sendmail()
self.clock = self.hs.get_clock()
- self._account_validity = self.hs.config.account_validity
+ self._account_validity_enabled = (
+ hs.config.account_validity.account_validity_enabled
+ )
+ self._account_validity_renew_by_email_enabled = (
+ hs.config.account_validity.account_validity_renew_by_email_enabled
+ )
+
+ self._account_validity_period = None
+ if self._account_validity_enabled:
+ self._account_validity_period = (
+ hs.config.account_validity.account_validity_period
+ )
if (
- self._account_validity.enabled
- and self._account_validity.renew_by_email_enabled
+ self._account_validity_enabled
+ and self._account_validity_renew_by_email_enabled
):
# Don't do email-specific configuration if renewal by email is disabled.
- self._template_html = self.config.account_validity_template_html
- self._template_text = self.config.account_validity_template_text
+ self._template_html = (
+ hs.config.account_validity.account_validity_template_html
+ )
+ self._template_text = (
+ hs.config.account_validity.account_validity_template_text
+ )
+ account_validity_renew_email_subject = (
+ hs.config.account_validity.account_validity_renew_email_subject
+ )
try:
- app_name = self.hs.config.email_app_name
+ app_name = hs.config.email_app_name
- self._subject = self._account_validity.renew_email_subject % {
- "app": app_name
- }
+ self._subject = account_validity_renew_email_subject % {"app": app_name}
- self._from_string = self.hs.config.email_notif_from % {"app": app_name}
+ self._from_string = hs.config.email_notif_from % {"app": app_name}
except Exception:
# If substitution failed, fall back to the bare strings.
- self._subject = self._account_validity.renew_email_subject
- self._from_string = self.hs.config.email_notif_from
+ self._subject = account_validity_renew_email_subject
+ self._from_string = hs.config.email_notif_from
self._raw_from = email.utils.parseaddr(self._from_string)[1]
@@ -220,50 +236,87 @@ class AccountValidityHandler:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- async def renew_account(self, renewal_token: str) -> bool:
+ async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
+ If it turns out that the token is valid but has already been used, then the
+ token is considered stale. A token is stale if the 'token_used_ts_ms' db column
+ is non-null.
+
Args:
renewal_token: Token sent with the renewal request.
Returns:
- Whether the provided token is valid.
+ A tuple containing:
+ * A bool representing whether the token is valid and unused.
+ * A bool which is `True` if the token is valid, but stale.
+ * An int representing the user's expiry timestamp as milliseconds since the
+ epoch, or 0 if the token was invalid.
"""
try:
- user_id = await self.store.get_user_from_renewal_token(renewal_token)
+ (
+ user_id,
+ current_expiration_ts,
+ token_used_ts,
+ ) = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- return False
+ return False, False, 0
+
+ # Check whether this token has already been used.
+ if token_used_ts:
+ logger.info(
+ "User '%s' attempted to use previously used token '%s' to renew account",
+ user_id,
+ renewal_token,
+ )
+ return False, True, current_expiration_ts
logger.debug("Renewing an account for user %s", user_id)
- await self.renew_account_for_user(user_id)
- return True
+ # Renew the account. Pass the renewal_token here so that it is not cleared.
+ # We want to keep the token around in case the user attempts to renew their
+ # account with the same token twice (clicking the email link twice).
+ #
+ # In that case, the token will be accepted, but the account's expiration ts
+ # will remain unchanged.
+ new_expiration_ts = await self.renew_account_for_user(
+ user_id, renewal_token=renewal_token
+ )
+
+ return True, False, new_expiration_ts
async def renew_account_for_user(
self,
user_id: str,
expiration_ts: Optional[int] = None,
email_sent: bool = False,
+ renewal_token: Optional[str] = None,
) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token: Token sent with the renewal request.
+ user_id: The ID of the user to renew.
expiration_ts: New expiration date. Defaults to now + validity period.
- email_sen: Whether an email has been sent for this validity period.
- Defaults to False.
+ email_sent: Whether an email has been sent for this validity period.
+ renewal_token: Token sent with the renewal request. The user's token
+ will be cleared if this is None.
Returns:
New expiration date for this account, as a timestamp in
milliseconds since epoch.
"""
+ now = self.clock.time_msec()
if expiration_ts is None:
- expiration_ts = self.clock.time_msec() + self._account_validity.period
+ expiration_ts = now + self._account_validity_period
await self.store.set_account_validity_for_user(
- user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
+ user_id=user_id,
+ expiration_ts=expiration_ts,
+ email_sent=email_sent,
+ renewal_token=renewal_token,
+ token_used_ts=now,
)
return expiration_ts
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index d7bc4e23ed..177310f0be 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
from prometheus_client import Counter
@@ -33,7 +33,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
+from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.util.metrics import Measure
if TYPE_CHECKING:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index b8a37b6477..36f2450e2e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1248,7 +1248,7 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- for token, token_id, device_id in tokens_and_devices:
+ for token, _, device_id in tokens_and_devices:
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas.py
index 7346ccfe93..7346ccfe93 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas.py
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 3f6f9f7f3d..45d2404dde 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -49,7 +49,9 @@ class DeactivateAccountHandler(BaseHandler):
if hs.config.run_background_tasks:
hs.get_reactor().callWhenRunning(self._start_user_parting)
- self._account_validity_enabled = hs.config.account_validity.enabled
+ self._account_validity_enabled = (
+ hs.config.account_validity.account_validity_enabled
+ )
async def deactivate_account(
self,
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d75edb184b..95bdc5902a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -28,7 +28,6 @@ from synapse.api.errors import (
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
- Collection,
JsonDict,
StreamToken,
UserID,
@@ -156,8 +155,7 @@ class DeviceWorkerHandler(BaseHandler):
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
- for key, event_id in current_state_ids.items():
- etype, state_key = key
+ for etype, state_key in current_state_ids.keys():
if etype != EventTypes.Member:
continue
possibly_left.add(state_key)
@@ -179,8 +177,7 @@ class DeviceWorkerHandler(BaseHandler):
log_kv(
{"event": "encountered empty previous state", "room_id": room_id}
)
- for key, event_id in current_state_ids.items():
- etype, state_key = key
+ for etype, state_key in current_state_ids.keys():
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
@@ -198,8 +195,7 @@ class DeviceWorkerHandler(BaseHandler):
for state_dict in prev_state_ids.values():
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
- for key, event_id in current_state_ids.items():
- etype, state_key = key
+ for etype, state_key in current_state_ids.keys():
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
@@ -714,7 +710,7 @@ class DeviceListUpdater:
# This can happen since we batch updates
return
- for device_id, stream_id, prev_ids, content in pending_updates:
+ for device_id, stream_id, prev_ids, _ in pending_updates:
logger.debug(
"Handling update %r/%r, ID: %r, prev: %r ",
user_id,
@@ -740,7 +736,7 @@ class DeviceListUpdater:
else:
# Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache)
- for device_id, stream_id, prev_ids, content in pending_updates:
+ for device_id, stream_id, _, content in pending_updates:
await self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id
)
@@ -929,6 +925,10 @@ class DeviceListUpdater:
else:
cached_devices = await self.store.get_cached_devices_for_user(user_id)
if cached_devices == {d["device_id"]: d for d in devices}:
+ logger.info(
+ "Skipping device list resync for %s, as our cache matches already",
+ user_id,
+ )
devices = []
ignore_devices = True
@@ -944,6 +944,9 @@ class DeviceListUpdater:
await self.store.update_remote_device_list_cache(
user_id, devices, stream_id
)
+ # mark the cache as valid, whether or not we actually processed any device
+ # list updates.
+ await self.store.mark_remote_user_device_cache_as_valid(user_id)
device_ids = [device["device_id"] for device in devices]
# Handle cross-signing keys.
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
new file mode 100644
index 0000000000..eff639f407
--- /dev/null
+++ b/synapse/handlers/event_auth.py
@@ -0,0 +1,86 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.room_versions import RoomVersion
+from synapse.types import StateMap
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class EventAuthHandler:
+ """
+ This class contains methods for authenticating events added to room graphs.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastore()
+
+ async def can_join_without_invite(
+ self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
+ ) -> bool:
+ """
+ Check whether a user can join a room without an invite.
+
+ When joining a room with restricted join rules (as defined in MSC3083),
+ the membership of spaces must be checked during join.
+
+ Args:
+ state_ids: The state of the room as it currently is.
+ room_version: The room version of the room being joined.
+ user_id: The user joining the room.
+
+ Returns:
+ True if the user can join the room, false otherwise.
+ """
+ # This only applies to room versions which support the new join rule.
+ if not room_version.msc3083_join_rules:
+ return True
+
+ # If there's no join rule, then it defaults to invite (so this doesn't apply).
+ join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+ if not join_rules_event_id:
+ return True
+
+ # If the join rule is not restricted, this doesn't apply.
+ join_rules_event = await self._store.get_event(join_rules_event_id)
+ if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
+ return True
+
+ # If allowed is of the wrong form, then only allow invited users.
+ allowed_spaces = join_rules_event.content.get("allow", [])
+ if not isinstance(allowed_spaces, list):
+ return False
+
+ # Get the list of joined rooms and see if there's an overlap.
+ joined_rooms = await self._store.get_rooms_for_user(user_id)
+
+ # Pull out the other room IDs, invalid data gets filtered.
+ for space in allowed_spaces:
+ if not isinstance(space, dict):
+ continue
+
+ space_id = space.get("space")
+ if not isinstance(space_id, str):
+ continue
+
+ # The user was joined to one of the spaces specified, they can join
+ # this room!
+ if space_id in joined_rooms:
+ return True
+
+ # The user was not in any of the required spaces.
+ return False
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4b3730aa3b..9d867aaf4d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -146,6 +146,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
@@ -1673,8 +1674,40 @@ class FederationHandler(BaseHandler):
# would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin
+ # Calculate the event context.
context = await self.state_handler.compute_event_context(event)
- context = await self._auth_and_persist_event(origin, event, context)
+
+ # Get the state before the new event.
+ prev_state_ids = await context.get_prev_state_ids()
+
+ # Check if the user is already in the room or invited to the room.
+ user_id = event.state_key
+ prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
+ newly_joined = True
+ user_is_invited = False
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+ newly_joined = prev_member_event.membership != Membership.JOIN
+ user_is_invited = prev_member_event.membership == Membership.INVITE
+
+ # If the member is not already in the room, and not invited, check if
+ # they should be allowed access via membership in a space.
+ if (
+ newly_joined
+ and not user_is_invited
+ and not await self._event_auth_handler.can_join_without_invite(
+ prev_state_ids,
+ event.room_version,
+ user_id,
+ )
+ ):
+ raise AuthError(
+ 403,
+ "You do not belong to any of the required spaces to join this room.",
+ )
+
+ # Persist the event.
+ await self._auth_and_persist_event(origin, event, context)
logger.debug(
"on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
@@ -1682,8 +1715,6 @@ class FederationHandler(BaseHandler):
event.signatures,
)
- prev_state_ids = await context.get_prev_state_ids()
-
state_ids = list(prev_state_ids.values())
auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
@@ -2006,7 +2037,7 @@ class FederationHandler(BaseHandler):
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False,
- ) -> EventContext:
+ ) -> None:
"""
Process an event by performing auth checks and then persisting to the database.
@@ -2028,9 +2059,6 @@ class FederationHandler(BaseHandler):
event is an outlier), may be the auth events claimed by the remote
server.
backfilled: True if the event was backfilled.
-
- Returns:
- The event context.
"""
context = await self._check_event_auth(
origin,
@@ -2060,8 +2088,6 @@ class FederationHandler(BaseHandler):
)
raise
- return context
-
async def _auth_and_persist_events(
self,
origin: str,
@@ -2956,7 +2982,7 @@ class FederationHandler(BaseHandler):
try:
# for each sig on the third_party_invite block of the actual invite
for server, signature_block in signed["signatures"].items():
- for key_name, encoded_signature in signature_block.items():
+ for key_name in signature_block.keys():
if not key_name.startswith("ed25519:"):
continue
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 87a8b89237..0b3b1fadb5 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -15,7 +15,6 @@
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
-
import logging
import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
@@ -34,7 +33,11 @@ from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
-from synapse.util.stringutils import assert_valid_client_secret, random_string
+from synapse.util.stringutils import (
+ assert_valid_client_secret,
+ random_string,
+ valid_id_server_location,
+)
from ._base import BaseHandler
@@ -172,6 +175,11 @@ class IdentityHandler(BaseHandler):
server with, if necessary. Required if use_v2 is true
use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
+ Raises:
+ SynapseError: On any of the following conditions
+ - the supplied id_server is not a valid identity server name
+ - we failed to contact the supplied identity server
+
Returns:
The response from the identity server
"""
@@ -181,6 +189,12 @@ class IdentityHandler(BaseHandler):
if id_access_token is None:
use_v2 = False
+ if not valid_id_server_location(id_server):
+ raise SynapseError(
+ 400,
+ "id_server must be a valid hostname with optional port and path components",
+ )
+
# Decide which API endpoint URLs to use
headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
@@ -269,12 +283,21 @@ class IdentityHandler(BaseHandler):
id_server: Identity server to unbind from
Raises:
- SynapseError: If we failed to contact the identity server
+ SynapseError: On any of the following conditions
+ - the supplied id_server is not a valid identity server name
+ - we failed to contact the supplied identity server
Returns:
True on success, otherwise False if the identity
server doesn't support unbinding
"""
+
+ if not valid_id_server_location(id_server):
+ raise SynapseError(
+ 400,
+ "id_server must be a valid hostname with optional port and path components",
+ )
+
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc.py
index b156196a70..ee6e41c0e4 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc.py
@@ -15,7 +15,7 @@
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
-from urllib.parse import urlencode
+from urllib.parse import urlencode, urlparse
import attr
import pymacaroons
@@ -37,10 +37,7 @@ from twisted.web.client import readBody
from twisted.web.http_headers import Headers
from synapse.config import ConfigError
-from synapse.config.oidc_config import (
- OidcProviderClientSecretJwtKey,
- OidcProviderConfig,
-)
+from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
@@ -71,8 +68,8 @@ logger = logging.getLogger(__name__)
#
# Here we have the names of the cookies, and the options we use to set them.
_SESSION_COOKIES = [
- (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
- (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
+ (b"oidc_session", b"HttpOnly; Secure; SameSite=None"),
+ (b"oidc_session_no_samesite", b"HttpOnly"),
]
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
@@ -282,6 +279,13 @@ class OidcProvider:
self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
+ # Calculate the prefix for OIDC callback paths based on the public_baseurl.
+ # We'll insert this into the Path= parameter of any session cookies we set.
+ public_baseurl_path = urlparse(hs.config.server.public_baseurl).path
+ self._callback_path_prefix = (
+ public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc"
+ )
+
self._oidc_attribute_requirements = provider.attribute_requirements
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
@@ -782,8 +786,13 @@ class OidcProvider:
for cookie_name, options in _SESSION_COOKIES:
request.cookies.append(
- b"%s=%s; Max-Age=3600; %s"
- % (cookie_name, cookie.encode("utf-8"), options)
+ b"%s=%s; Max-Age=3600; Path=%s; %s"
+ % (
+ cookie_name,
+ cookie.encode("utf-8"),
+ self._callback_path_prefix,
+ options,
+ )
)
metadata = await self.load_metadata()
@@ -960,6 +969,11 @@ class OidcProvider:
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
+ if attributes.localpart is None:
+ # If no localpart is returned then we will generate one, so
+ # there is no need to search for existing users.
+ return None
+
user_id = UserID(attributes.localpart, self._server_name).to_string()
users = await self._store.get_users_by_id_case_insensitive(user_id)
if users:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6460eb9952..969c73c1e7 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -24,9 +24,11 @@ The methods that define policy are:
import abc
import contextlib
import logging
+from bisect import bisect
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
+ Collection,
Dict,
FrozenSet,
Iterable,
@@ -53,10 +55,11 @@ from synapse.replication.http.presence import (
ReplicationBumpPresenceActiveTime,
ReplicationPresenceSetState,
)
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand
-from synapse.state import StateHandler
+from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
-from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.metrics import Measure
@@ -118,19 +121,21 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class BasePresenceHandler(abc.ABC):
- """Parts of the PresenceHandler that are shared between workers and master"""
+ """Parts of the PresenceHandler that are shared between workers and presence
+ writer"""
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.presence_router = hs.get_presence_router()
self.state = hs.get_state_handler()
+ self.is_mine_id = hs.is_mine_id
self._federation = None
- if hs.should_send_federation() or not hs.config.worker_app:
+ if hs.should_send_federation():
self._federation = hs.get_federation_sender()
- self._send_federation = hs.should_send_federation()
+ self._federation_queue = PresenceFederationQueue(hs, self)
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
@@ -253,28 +258,38 @@ class BasePresenceHandler(abc.ABC):
"""
pass
- async def process_replication_rows(self, token, rows):
- """Process presence stream rows received over replication."""
- pass
+ async def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
+ """Process streams received over replication."""
+ await self._federation_queue.process_replication_rows(
+ stream_name, instance_name, token, rows
+ )
+
+ def get_federation_queue(self) -> "PresenceFederationQueue":
+ """Get the presence federation queue."""
+ return self._federation_queue
async def maybe_send_presence_to_interested_destinations(
self, states: List[UserPresenceState]
):
"""If this instance is a federation sender, send the states to all
- destinations that are interested.
+ destinations that are interested. Filters out any states for remote
+ users.
"""
- if not self._send_federation:
+ if not self._federation:
return
- # If this worker sends federation we must have a FederationSender.
- assert self._federation
+ states = [s for s in states if self.is_mine_id(s.user_id)]
+
+ if not states:
+ return
hosts_and_states = await get_interested_remotes(
self.store,
self.presence_router,
states,
- self.state,
)
for destinations, states in hosts_and_states:
@@ -292,10 +307,17 @@ class WorkerPresenceHandler(BasePresenceHandler):
def __init__(self, hs):
super().__init__(hs)
self.hs = hs
- self.is_mine_id = hs.is_mine_id
+
+ self._presence_writer_instance = hs.config.worker.writers.presence[0]
self._presence_enabled = hs.config.use_presence
+ # Route presence EDUs to the right worker
+ hs.get_federation_registry().register_instances_for_edu(
+ "m.presence",
+ hs.config.worker.writers.presence,
+ )
+
# The number of ongoing syncs on this process, by user id.
# Empty if _presence_enabled is false.
self._user_to_num_current_syncs = {} # type: Dict[str, int]
@@ -303,8 +325,8 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
- # user_id -> last_sync_ms. Lists the users that have stopped syncing
- # but we haven't notified the master of that yet
+ # user_id -> last_sync_ms. Lists the users that have stopped syncing but
+ # we haven't notified the presence writer of that yet
self.users_going_offline = {}
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
@@ -337,22 +359,23 @@ class WorkerPresenceHandler(BasePresenceHandler):
)
def mark_as_coming_online(self, user_id):
- """A user has started syncing. Send a UserSync to the master, unless they
- had recently stopped syncing.
+ """A user has started syncing. Send a UserSync to the presence writer,
+ unless they had recently stopped syncing.
Args:
user_id (str)
"""
going_offline = self.users_going_offline.pop(user_id, None)
if not going_offline:
- # Safe to skip because we haven't yet told the master they were offline
+ # Safe to skip because we haven't yet told the presence writer they
+ # were offline
self.send_user_sync(user_id, True, self.clock.time_msec())
def mark_as_going_offline(self, user_id):
- """A user has stopped syncing. We wait before notifying the master as
- its likely they'll come back soon. This allows us to avoid sending
- a stopped syncing immediately followed by a started syncing notification
- to the master
+ """A user has stopped syncing. We wait before notifying the presence
+ writer as it's likely they'll come back soon. This allows us to avoid
+ sending a stopped syncing immediately followed by a started syncing
+ notification to the presence writer
Args:
user_id (str)
@@ -360,8 +383,8 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.users_going_offline[user_id] = self.clock.time_msec()
def send_stop_syncing(self):
- """Check if there are any users who have stopped syncing a while ago
- and haven't come back yet. If there are poke the master about them.
+ """Check if there are any users who have stopped syncing a while ago and
+ haven't come back yet. If there are, poke the presence writer about them.
"""
now = self.clock.time_msec()
for user_id, last_sync_ms in list(self.users_going_offline.items()):
@@ -421,7 +444,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
# If this is a federation sender, notify about presence updates.
await self.maybe_send_presence_to_interested_destinations(states)
- async def process_replication_rows(self, token, rows):
+ async def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
+ await super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ if stream_name != PresenceStream.NAME:
+ return
+
states = [
UserPresenceState(
row.user_id,
@@ -470,9 +500,12 @@ class WorkerPresenceHandler(BasePresenceHandler):
if not self.hs.config.use_presence:
return
- # Proxy request to master
+ # Proxy request to instance that writes presence
await self._set_state_client(
- user_id=user_id, state=state, ignore_status_msg=ignore_status_msg
+ instance_name=self._presence_writer_instance,
+ user_id=user_id,
+ state=state,
+ ignore_status_msg=ignore_status_msg,
)
async def bump_presence_active_time(self, user):
@@ -483,16 +516,17 @@ class WorkerPresenceHandler(BasePresenceHandler):
if not self.hs.config.use_presence:
return
- # Proxy request to master
+ # Proxy request to instance that writes presence
user_id = user.to_string()
- await self._bump_active_client(user_id=user_id)
+ await self._bump_active_client(
+ instance_name=self._presence_writer_instance, user_id=user_id
+ )
class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
- self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
@@ -721,15 +755,12 @@ class PresenceHandler(BasePresenceHandler):
self.store,
self.presence_router,
list(to_federation_ping.values()),
- self.state,
)
- # Since this is master we know that we have a federation sender or
- # queue, and so this will be defined.
- assert self._federation
-
for destinations, states in hosts_and_states:
- self._federation.send_presence_to_destinations(states, destinations)
+ self._federation_queue.send_presence_to_destinations(
+ states, destinations
+ )
async def _handle_timeouts(self):
"""Checks the presence of users that have timed out and updates as
@@ -1208,13 +1239,9 @@ class PresenceHandler(BasePresenceHandler):
user_presence_states
)
- # Since this is master we know that we have a federation sender or
- # queue, and so this will be defined.
- assert self._federation
-
# Send out user presence updates for each destination
for destination, user_state_set in presence_destinations.items():
- self._federation.send_presence_to_destinations(
+ self._federation_queue.send_presence_to_destinations(
destinations=[destination], states=user_state_set
)
@@ -1354,7 +1381,6 @@ class PresenceEventSource:
self.get_presence_router = hs.get_presence_router
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- self.state = hs.get_state_handler()
@log_function
async def get_new_events(
@@ -1823,7 +1849,6 @@ async def get_interested_remotes(
store: DataStore,
presence_router: PresenceRouter,
states: List[UserPresenceState],
- state_handler: StateHandler,
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@@ -1834,7 +1859,6 @@ async def get_interested_remotes(
store: The homeserver's data store.
presence_router: A module for augmenting the destinations for presence updates.
states: A list of incoming user presence updates.
- state_handler:
Returns:
A list of 2-tuples of destinations and states, where for
@@ -1851,7 +1875,8 @@ async def get_interested_remotes(
)
for room_id, states in room_ids_to_states.items():
- hosts = await state_handler.get_current_hosts_in_room(room_id)
+ user_ids = await store.get_users_in_room(room_id)
+ hosts = {get_domain_from_id(user_id) for user_id in user_ids}
hosts_and_states.append((hosts, states))
for user_id, states in users_to_states.items():
@@ -1859,3 +1884,198 @@ async def get_interested_remotes(
hosts_and_states.append(([host], states))
return hosts_and_states
+
+
+class PresenceFederationQueue:
+    """Handles sending ad hoc presence updates over federation, which are *not*
+    due to state updates (that get handled via the presence stream), e.g.
+    federation pings and sending existing present states to newly joined hosts.
+
+    Only the last N minutes will be queued, so if a federation sender instance
+    is down for longer then some updates will be dropped. This is OK as presence
+    is ephemeral, and so it will self correct eventually.
+
+    On workers the class tracks the last received position of the stream from
+    replication, and handles querying for missed updates over HTTP replication,
+    c.f. `get_current_token` and `get_replication_rows`.
+    """
+
+    # How long to keep entries in the queue for. Workers that are down for
+    # longer than this duration will miss out on older updates.
+    _KEEP_ITEMS_IN_QUEUE_FOR_MS = 5 * 60 * 1000
+
+    # How often to check if we can expire entries from the queue.
+    _CLEAR_ITEMS_EVERY_MS = 60 * 1000
+
+    def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler):
+        self._clock = hs.get_clock()
+        self._notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
+        self._presence_handler = presence_handler
+        self._repl_client = ReplicationGetStreamUpdates.make_client(hs)
+
+        # Should we keep a queue of recent presence updates? We only bother if
+        # another process may be handling federation sending.
+        self._queue_presence_updates = True
+
+        # Whether this instance is a presence writer.
+        self._presence_writer = self._instance_name in hs.config.worker.writers.presence
+
+        # The FederationSender instance, if this process sends federation traffic directly.
+        self._federation = None
+
+        if hs.should_send_federation():
+            self._federation = hs.get_federation_sender()
+
+            # We don't bother queuing up presence states if only this instance
+            # is sending federation.
+            if hs.config.worker.federation_shard_config.instances == [
+                self._instance_name
+            ]:
+                self._queue_presence_updates = False
+
+        # The queue of recently queued updates as tuples of: `(timestamp,
+        # stream_id, destinations, user_ids)`. We don't store the full states
+        # for efficiency, and remote workers will already have the full states
+        # cached.
+        self._queue = []  # type: List[Tuple[int, int, Collection[str], Set[str]]]
+
+        # The stream ID that will be given to the next queued update.
+        self._next_id = 1
+
+        # Map from instance name to current token
+        self._current_tokens = {}  # type: Dict[str, int]
+
+        if self._queue_presence_updates:
+            self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
+
+    def _clear_queue(self):
+        """Clear out older entries from the queue."""
+        clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
+
+        # The queue is sorted by timestamp, so we can bisect to find the right
+        # place to purge before. Note that we are searching using a 1-tuple with
+        # the time, which does The Right Thing since the queue is a tuple where
+        # the first item is a timestamp.
+        index = bisect(self._queue, (clear_before,))
+        self._queue = self._queue[index:]
+
+    def send_presence_to_destinations(
+        self, states: Collection[UserPresenceState], destinations: Collection[str]
+    ) -> None:
+        """Send the presence states to the given destinations.
+
+        Will forward to the local federation sender (if there is one) and queue
+        to send over replication (if there are other federation sender instances).
+
+        Must only be called on the presence writer process.
+        """
+        # This should only be called on a presence writer.
+        assert self._presence_writer
+
+        if self._federation:
+            self._federation.send_presence_to_destinations(
+                states=states, destinations=destinations
+            )
+
+        if not self._queue_presence_updates:
+            return
+
+        now = self._clock.time_msec()
+        stream_id = self._next_id
+        self._next_id += 1
+
+        self._queue.append((now, stream_id, destinations, {s.user_id for s in states}))
+
+        self._notifier.notify_replication()
+
+    def get_current_token(self, instance_name: str) -> int:
+        """Get the current position of the stream.
+
+        On workers this returns the last stream ID received from replication.
+        """
+        if instance_name == self._instance_name:
+            return self._next_id - 1
+        return self._current_tokens.get(instance_name, 0)
+
+    async def get_replication_rows(
+        self,
+        instance_name: str,
+        from_token: int,
+        upto_token: int,
+        target_row_count: int,
+    ) -> Tuple[List[Tuple[int, Tuple[str, str]]], int, bool]:
+        """Get all the updates between the two tokens.
+
+        We return rows in the form of `(destination, user_id)` to keep the size
+        of each row bounded (rather than returning the sets in a row).
+
+        On workers this will query the presence writer process via HTTP replication.
+
+        Returns:
+            A tuple `(rows, new_token, limited)`. `rows` only covers stream IDs
+            after `from_token` and up to `upto_token`; `limited` is set if we
+            stopped early because more than `target_row_count` rows were found.
+        """
+        if instance_name != self._instance_name:
+            # If not local we query over HTTP replication from the presence writer.
+            result = await self._repl_client(
+                instance_name=instance_name,
+                stream_name=PresenceFederationStream.NAME,
+                from_token=from_token,
+                upto_token=upto_token,
+            )
+            return result["updates"], result["upto_token"], result["limited"]
+
+        # There is exactly one entry per stream ID and the last entry has an ID
+        # of `self._next_id - 1`, so the entries *after* `from_token` are the
+        # final `self._next_id - 1 - from_token` ones. The start of the queue is
+        # periodically truncated, so cap that at the current queue length.
+        num_new = min(self._next_id - 1 - from_token, len(self._queue))
+        if num_new <= 0:
+            # Nothing new; also avoids `[-0:]`, which would be the whole queue.
+            return [], upto_token, False
+
+        to_send = []  # type: List[Tuple[int, Tuple[str, str]]]
+        limited = False
+        new_id = upto_token
+        for _, stream_id, destinations, user_ids in self._queue[-num_new:]:
+            if stream_id > upto_token:
+                break
+            new_id = stream_id
+            to_send.extend(
+                (stream_id, (destination, user_id))
+                for destination in destinations
+                for user_id in user_ids
+            )
+
+            if len(to_send) > target_row_count:
+                limited = True
+                break
+
+        return to_send, new_id, limited
+
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        """Handle presence federation stream rows received over replication."""
+        if stream_name != PresenceFederationStream.NAME:
+            return
+
+        # We keep track of the current tokens (so that we can catch up with anything we missed after a disconnect)
+        self._current_tokens[instance_name] = token
+
+        # If we're a federation sender we pull out the presence states to send
+        # and forward them on.
+        if not self._federation:
+            return
+
+        hosts_to_users = {}  # type: Dict[str, Set[str]]
+        for row in rows:
+            hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
+
+        for host, user_ids in hosts_to_users.items():
+            states = await self._presence_handler.current_state_for_users(user_ids)
+            self._federation.send_presence_to_destinations(
+                states=states.values(), destinations=[host]
+            )
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6e28677530..54c25e3557 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -19,7 +19,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types
-from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -28,7 +28,6 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
-from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
@@ -64,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.account_data_handler = hs.get_account_data_handler()
+ self.event_auth_handler = hs.get_event_auth_handler()
self.member_linearizer = Linearizer(name="member")
self.member_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -179,62 +179,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
- async def _can_join_without_invite(
- self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
- ) -> bool:
- """
- Check whether a user can join a room without an invite.
-
- When joining a room with restricted joined rules (as defined in MSC3083),
- the membership of spaces must be checked during join.
-
- Args:
- state_ids: The state of the room as it currently is.
- room_version: The room version of the room being joined.
- user_id: The user joining the room.
-
- Returns:
- True if the user can join the room, false otherwise.
- """
- # This only applies to room versions which support the new join rule.
- if not room_version.msc3083_join_rules:
- return True
-
- # If there's no join rule, then it defaults to public (so this doesn't apply).
- join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
- if not join_rules_event_id:
- return True
-
- # If the join rule is not restricted, this doesn't apply.
- join_rules_event = await self.store.get_event(join_rules_event_id)
- if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
- return True
-
- # If allowed is of the wrong form, then only allow invited users.
- allowed_spaces = join_rules_event.content.get("allow", [])
- if not isinstance(allowed_spaces, list):
- return False
-
- # Get the list of joined rooms and see if there's an overlap.
- joined_rooms = await self.store.get_rooms_for_user(user_id)
-
- # Pull out the other room IDs, invalid data gets filtered.
- for space in allowed_spaces:
- if not isinstance(space, dict):
- continue
-
- space_id = space.get("space")
- if not isinstance(space_id, str):
- continue
-
- # The user was joined to one of the spaces specified, they can join
- # this room!
- if space_id in joined_rooms:
- return True
-
- # The user was not in any of the required spaces.
- return False
-
async def _local_membership_update(
self,
requester: Requester,
@@ -303,7 +247,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if (
newly_joined
and not user_is_invited
- and not await self._can_join_without_invite(
+ and not await self.event_auth_handler.can_join_without_invite(
prev_state_ids, event.room_version, user_id
)
):
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml.py
index 80ba65b9e0..80ba65b9e0 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml.py
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 8d00ffdc73..044ff06d84 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -18,6 +18,7 @@ from typing import (
Any,
Awaitable,
Callable,
+ Collection,
Dict,
Iterable,
List,
@@ -40,7 +41,7 @@ from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
from synapse.http.site import SynapseRequest
-from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer
from synapse.util.stringutils import random_string
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d2ba805f86..3ffc4628cb 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -14,7 +14,17 @@
# limitations under the License.
import itertools
import logging
-from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ FrozenSet,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
import attr
from prometheus_client import Counter
@@ -28,7 +38,6 @@ from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
- Collection,
JsonDict,
MutableStateMap,
Requester,
|