diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4b66a9862f..61607cf2ba 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,6 +18,7 @@ import time
import unicodedata
import urllib.parse
from binascii import crc32
+from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@@ -38,6 +39,7 @@ import attr
import bcrypt
import pymacaroons
import unpaddedbase64
+from pymacaroons.exceptions import MacaroonVerificationFailedException
from twisted.web.server import Request
@@ -181,8 +183,11 @@ class LoginTokenAttributes:
user_id = attr.ib(type=str)
- # the SSO Identity Provider that the user authenticated with, to get this token
auth_provider_id = attr.ib(type=str)
+ """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+ auth_provider_session_id = attr.ib(type=Optional[str])
+ """The session ID advertised by the SSO Identity Provider."""
class AuthHandler:
@@ -756,53 +761,109 @@ class AuthHandler:
async def refresh_token(
self,
refresh_token: str,
- valid_until_ms: Optional[int],
- ) -> Tuple[str, str]:
+ access_token_valid_until_ms: Optional[int],
+ refresh_token_valid_until_ms: Optional[int],
+ ) -> Tuple[str, str, Optional[int]]:
"""
Consumes a refresh token and generate both a new access token and a new refresh token from it.
The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
+ The lifetime of both the access token and refresh token will be capped so that they
+ do not exceed the session's ultimate expiry time, if applicable.
+
Args:
refresh_token: The token to consume.
- valid_until_ms: The expiration timestamp of the new access token.
-
+ access_token_valid_until_ms: The expiration timestamp of the new access token.
+ None if the access token does not expire.
+ refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
+ None if the refresh token does not expire.
Returns:
- A tuple containing the new access token and refresh token
+ A tuple containing:
+ - the new access token
+ - the new refresh token
+ - the actual expiry time of the access token, which may be earlier than
+ `access_token_valid_until_ms`.
"""
# Verify the token signature first before looking up the token
if not self._verify_refresh_token(refresh_token):
- raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+ raise SynapseError(
+ HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
+ )
existing_token = await self.store.lookup_refresh_token(refresh_token)
if existing_token is None:
- raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+ raise SynapseError(
+ HTTPStatus.UNAUTHORIZED,
+ "refresh token does not exist",
+ Codes.UNKNOWN_TOKEN,
+ )
if (
existing_token.has_next_access_token_been_used
or existing_token.has_next_refresh_token_been_refreshed
):
raise SynapseError(
- 403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+ HTTPStatus.FORBIDDEN,
+ "refresh token isn't valid anymore",
+ Codes.FORBIDDEN,
)
+ now_ms = self._clock.time_msec()
+
+ if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
+
+ raise SynapseError(
+ HTTPStatus.FORBIDDEN,
+ "The supplied refresh token has expired",
+ Codes.FORBIDDEN,
+ )
+
+ if existing_token.ultimate_session_expiry_ts is not None:
+ # This session has a bounded lifetime, even across refreshes.
+
+ if access_token_valid_until_ms is not None:
+ access_token_valid_until_ms = min(
+ access_token_valid_until_ms,
+ existing_token.ultimate_session_expiry_ts,
+ )
+ else:
+ access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+
+ if refresh_token_valid_until_ms is not None:
+ refresh_token_valid_until_ms = min(
+ refresh_token_valid_until_ms,
+ existing_token.ultimate_session_expiry_ts,
+ )
+ else:
+ refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+ if existing_token.ultimate_session_expiry_ts < now_ms:
+ raise SynapseError(
+ HTTPStatus.FORBIDDEN,
+ "The session has expired and can no longer be refreshed",
+ Codes.FORBIDDEN,
+ )
+
(
new_refresh_token,
new_refresh_token_id,
) = await self.create_refresh_token_for_user_id(
- user_id=existing_token.user_id, device_id=existing_token.device_id
+ user_id=existing_token.user_id,
+ device_id=existing_token.device_id,
+ expiry_ts=refresh_token_valid_until_ms,
+ ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
)
access_token = await self.create_access_token_for_user_id(
user_id=existing_token.user_id,
device_id=existing_token.device_id,
- valid_until_ms=valid_until_ms,
+ valid_until_ms=access_token_valid_until_ms,
refresh_token_id=new_refresh_token_id,
)
await self.store.replace_refresh_token(
existing_token.token_id, new_refresh_token_id
)
- return access_token, new_refresh_token
+ return access_token, new_refresh_token, access_token_valid_until_ms
def _verify_refresh_token(self, token: str) -> bool:
"""
@@ -836,6 +897,8 @@ class AuthHandler:
self,
user_id: str,
device_id: str,
+ expiry_ts: Optional[int],
+ ultimate_session_expiry_ts: Optional[int],
) -> Tuple[str, int]:
"""
Creates a new refresh token for the user with the given user ID.
@@ -843,6 +906,13 @@ class AuthHandler:
Args:
user_id: canonical user ID
device_id: the device ID to associate with the token.
+ expiry_ts (milliseconds since the epoch): Time after which the
+ refresh token cannot be used.
+ If None, the refresh token never expires until it has been used.
+ ultimate_session_expiry_ts (milliseconds since the epoch):
+ Time at which the session will end and can not be extended any
+ further.
+ If None, the session can be refreshed indefinitely.
Returns:
The newly created refresh token and its ID in the database
@@ -852,6 +922,8 @@ class AuthHandler:
user_id=user_id,
token=refresh_token,
device_id=device_id,
+ expiry_ts=expiry_ts,
+ ultimate_session_expiry_ts=ultimate_session_expiry_ts,
)
return refresh_token, refresh_token_id
@@ -1582,6 +1654,7 @@ class AuthHandler:
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
+ auth_provider_session_id: Optional[str] = None,
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
@@ -1597,6 +1670,7 @@ class AuthHandler:
during successful login. Must be JSON serializable.
new_user: True if we should use wording appropriate to a user who has just
registered.
+ auth_provider_session_id: The session ID from the SSO IdP received during login.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
@@ -1617,6 +1691,7 @@ class AuthHandler:
extra_attributes,
new_user=new_user,
user_profile_data=profile,
+ auth_provider_session_id=auth_provider_session_id,
)
def _complete_sso_login(
@@ -1628,6 +1703,7 @@ class AuthHandler:
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> None:
"""
The synchronous portion of complete_sso_login.
@@ -1649,7 +1725,9 @@ class AuthHandler:
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id, auth_provider_id=auth_provider_id
+ registered_user_id,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
# Append the login token to the original redirect URL (i.e. with its query
@@ -1754,6 +1832,7 @@ class MacaroonGenerator:
self,
user_id: str,
auth_provider_id: str,
+ auth_provider_session_id: Optional[str] = None,
duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
@@ -1762,6 +1841,10 @@ class MacaroonGenerator:
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
+ if auth_provider_session_id is not None:
+ macaroon.add_first_party_caveat(
+ "auth_provider_session_id = %s" % (auth_provider_session_id,)
+ )
return macaroon.serialize()
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@@ -1783,15 +1866,28 @@ class MacaroonGenerator:
user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+ auth_provider_session_id: Optional[str] = None
+ try:
+ auth_provider_session_id = get_value_from_macaroon(
+ macaroon, "auth_provider_session_id"
+ )
+ except MacaroonVerificationFailedException:
+ pass
+
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+ v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
satisfy_expiry(v, self.hs.get_clock().time_msec)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
- return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+ return LoginTokenAttributes(
+ user_id=user_id,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 68b446eb66..82ee11e921 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: str,
device_id: Optional[str],
initial_device_display_name: Optional[str] = None,
+ auth_provider_id: Optional[str] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> str:
"""
If the given device has not been registered, register it with the
@@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: @user:id
device_id: device id supplied by client
initial_device_display_name: device display name from client
+ auth_provider_id: The SSO IdP the user used, if any.
+ auth_provider_session_id: The session ID (sid) got from the SSO IdP.
Returns:
device id (generated if none was supplied)
"""
@@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [device_id])
@@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=new_device_id,
initial_device_display_name=initial_device_display_name,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [new_device_id])
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index b4ff935546..32b0254c5f 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -122,9 +122,8 @@ class EventStreamHandler:
events,
time_now,
as_client_event=as_client_event,
- # We don't bundle "live" events, as otherwise clients
- # will end up double counting annotations.
- bundle_relations=False,
+ # Don't bundle aggregations as this is a deprecated API.
+ bundle_aggregations=False,
)
chunk = {
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3112cc88b1..1ea837d082 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -68,6 +68,37 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
+ """Get joined domains from state
+
+ Args:
+ state: State map from type/state key to event.
+
+ Returns:
+ Returns a list of servers with the lowest depth of their joins.
+ Sorted by lowest depth first.
+ """
+ joined_users = [
+ (state_key, int(event.depth))
+ for (e_type, state_key), event in state.items()
+ if e_type == EventTypes.Member and event.membership == Membership.JOIN
+ ]
+
+ joined_domains: Dict[str, int] = {}
+ for u, d in joined_users:
+ try:
+ dom = get_domain_from_id(u)
+ old_d = joined_domains.get(dom)
+ if old_d:
+ joined_domains[dom] = min(d, old_d)
+ else:
+ joined_domains[dom] = d
+ except Exception:
+ pass
+
+ return sorted(joined_domains.items(), key=lambda d: d[1])
+
+
class FederationHandler:
"""Handles general incoming federation requests
@@ -268,36 +299,6 @@ class FederationHandler:
curr_state = await self.state_handler.get_current_state(room_id)
- def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
- """Get joined domains from state
-
- Args:
- state: State map from type/state key to event.
-
- Returns:
- Returns a list of servers with the lowest depth of their joins.
- Sorted by lowest depth first.
- """
- joined_users = [
- (state_key, int(event.depth))
- for (e_type, state_key), event in state.items()
- if e_type == EventTypes.Member and event.membership == Membership.JOIN
- ]
-
- joined_domains: Dict[str, int] = {}
- for u, d in joined_users:
- try:
- dom = get_domain_from_id(u)
- old_d = joined_domains.get(dom)
- if old_d:
- joined_domains[dom] = min(d, old_d)
- else:
- joined_domains[dom] = d
- except Exception:
- pass
-
- return sorted(joined_domains.items(), key=lambda d: d[1])
-
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d4e4556155..9cd21e7f2b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -165,7 +165,11 @@ class InitialSyncHandler:
invite_event = await self.store.get_event(event.event_id)
d["invite"] = await self._event_serializer.serialize_event(
- invite_event, time_now, as_client_event
+ invite_event,
+ time_now,
+ # Don't bundle aggregations as this is a deprecated API.
+ bundle_aggregations=False,
+ as_client_event=as_client_event,
)
rooms_ret.append(d)
@@ -216,7 +220,11 @@ class InitialSyncHandler:
d["messages"] = {
"chunk": (
await self._event_serializer.serialize_events(
- messages, time_now=time_now, as_client_event=as_client_event
+ messages,
+ time_now=time_now,
+ # Don't bundle aggregations as this is a deprecated API.
+ bundle_aggregations=False,
+ as_client_event=as_client_event,
)
),
"start": await start_token.to_string(self.store),
@@ -226,6 +234,8 @@ class InitialSyncHandler:
d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
+ # Don't bundle aggregations as this is a deprecated API.
+ bundle_aggregations=False,
as_client_event=as_client_event,
)
@@ -366,14 +376,18 @@ class InitialSyncHandler:
"room_id": room_id,
"messages": {
"chunk": (
- await self._event_serializer.serialize_events(messages, time_now)
+ # Don't bundle aggregations as this is a deprecated API.
+ await self._event_serializer.serialize_events(
+ messages, time_now, bundle_aggregations=False
+ )
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
+ # Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
- room_state.values(), time_now
+ room_state.values(), time_now, bundle_aggregations=False
)
),
"presence": [],
@@ -392,8 +406,9 @@ class InitialSyncHandler:
# TODO: These concurrently
time_now = self.clock.time_msec()
+ # Don't bundle aggregations as this is a deprecated API.
state = await self._event_serializer.serialize_events(
- current_state.values(), time_now
+ current_state.values(), time_now, bundle_aggregations=False
)
now_token = self.hs.get_event_sources().get_current_token()
@@ -467,7 +482,10 @@ class InitialSyncHandler:
"room_id": room_id,
"messages": {
"chunk": (
- await self._event_serializer.serialize_events(messages, time_now)
+ # Don't bundle aggregations as this is a deprecated API.
+ await self._event_serializer.serialize_events(
+ messages, time_now, bundle_aggregations=False
+ )
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 95b4fad3c6..87f671708c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -247,13 +247,7 @@ class MessageHandler:
room_state = room_state_events[membership_event_id]
now = self.clock.time_msec()
- events = await self._event_serializer.serialize_events(
- room_state.values(),
- now,
- # We don't bother bundling aggregations in when asked for state
- # events, as clients won't use them.
- bundle_relations=False,
- )
+ events = await self._event_serializer.serialize_events(room_state.values(), now)
return events
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 3665d91513..deb3539751 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -23,7 +23,7 @@ from authlib.common.security import generate_token
from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
-from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.core import CodeIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
@@ -117,7 +117,8 @@ class OidcHandler:
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
- await p.load_jwks()
+ if not p._uses_userinfo:
+ await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
@@ -498,10 +499,6 @@ class OidcProvider:
return await self._jwks.get()
async def _load_jwks(self) -> JWKS:
- if self._uses_userinfo:
- # We're not using jwt signing, return an empty jwk set
- return {"keys": []}
-
metadata = await self.load_metadata()
# Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ class OidcProvider:
return UserInfo(resp)
- async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
"""Return an instance of UserInfo from token's ``id_token``.
Args:
@@ -673,7 +670,7 @@ class OidcProvider:
request. This value should match the one inside the token.
Returns:
- An object representing the user.
+ The decoded claims in the ID token.
"""
metadata = await self.load_metadata()
claims_params = {
@@ -684,9 +681,6 @@ class OidcProvider:
# If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
- claims_cls = CodeIDToken
- else:
- claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
@@ -703,7 +697,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
- claims_cls=claims_cls,
+ claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
@@ -713,7 +707,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
- claims_cls=claims_cls,
+ claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
@@ -721,7 +715,8 @@ class OidcProvider:
logger.debug("Decoded id_token JWT %r; validating", claims)
claims.validate(leeway=120) # allows 2 min of clock skew
- return UserInfo(claims)
+
+ return claims
async def handle_redirect_request(
self,
@@ -837,8 +832,22 @@ class OidcProvider:
logger.debug("Successfully obtained OAuth2 token data: %r", token)
- # Now that we have a token, get the userinfo, either by decoding the
- # `id_token` or by fetching the `userinfo_endpoint`.
+ # If there is an id_token, it should be validated, regardless of the
+ # userinfo endpoint is used or not.
+ if token.get("id_token") is not None:
+ try:
+ id_token = await self._parse_id_token(token, nonce=session_data.nonce)
+ sid = id_token.get("sid")
+ except Exception as e:
+ logger.exception("Invalid id_token")
+ self._sso_handler.render_error(request, "invalid_token", str(e))
+ return
+ else:
+ id_token = None
+ sid = None
+
+ # Now that we have a token, get the userinfo either from the `id_token`
+ # claims or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
try:
userinfo = await self._fetch_userinfo(token)
@@ -846,13 +855,14 @@ class OidcProvider:
logger.exception("Could not fetch userinfo")
self._sso_handler.render_error(request, "fetch_error", str(e))
return
+ elif id_token is not None:
+ userinfo = UserInfo(id_token)
else:
- try:
- userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
- except Exception as e:
- logger.exception("Invalid id_token")
- self._sso_handler.render_error(request, "invalid_token", str(e))
- return
+ logger.error("Missing id_token in token response")
+ self._sso_handler.render_error(
+ request, "invalid_token", "Missing id_token in token response"
+ )
+ return
# first check if we're doing a UIA
if session_data.ui_auth_session_id:
@@ -884,7 +894,7 @@ class OidcProvider:
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
- userinfo, token, request, session_data.client_redirect_url
+ userinfo, token, request, session_data.client_redirect_url, sid
)
except MappingException as e:
logger.exception("Could not map user")
@@ -896,6 +906,7 @@ class OidcProvider:
token: Token,
request: SynapseRequest,
client_redirect_url: str,
+ sid: Optional[str],
) -> None:
"""Given a UserInfo response, complete the login flow
@@ -1008,6 +1019,7 @@ class OidcProvider:
oidc_response_to_user_attributes,
grandfather_existing_users,
extra_attributes,
+ auth_provider_session_id=sid,
)
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index cd64142735..4f42438053 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -406,9 +406,6 @@ class PaginationHandler:
force: set true to skip checking for joined users.
"""
with await self.pagination_lock.write(room_id):
- # check we know about the room
- await self.store.get_room_version_id(room_id)
-
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 3df872c578..454d06c973 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -421,7 +421,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._on_shutdown,
)
- def _on_shutdown(self) -> None:
+ async def _on_shutdown(self) -> None:
if self._presence_enabled:
self.hs.get_tcp_replication().send_command(
ClearUserSyncsCommand(self.instance_id)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 448a36108e..f08a516a75 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -1,4 +1,5 @@
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -116,9 +117,13 @@ class RegistrationHandler:
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.registration.session_lifetime
+ self.nonrefreshable_access_token_lifetime = (
+ hs.config.registration.nonrefreshable_access_token_lifetime
+ )
self.refreshable_access_token_lifetime = (
hs.config.registration.refreshable_access_token_lifetime
)
+ self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
init_counters_for_auth_provider("")
@@ -741,6 +746,7 @@ class RegistrationHandler:
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
+ auth_provider_session_id: Optional[str] = None,
) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token.
@@ -751,9 +757,9 @@ class RegistrationHandler:
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
- auth_provider_id: The SSO IdP the user used, if any (just used for the
- prometheus metrics).
+ auth_provider_id: The SSO IdP the user used, if any.
should_issue_refresh_token: Whether it should also issue a refresh token
+ auth_provider_session_id: The session ID received during login from the SSO IdP.
Returns:
Tuple of device ID, access token, access token expiration time and refresh token
"""
@@ -764,6 +770,8 @@ class RegistrationHandler:
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
login_counter.labels(
@@ -786,6 +794,8 @@ class RegistrationHandler:
is_guest: bool = False,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
+ auth_provider_id: Optional[str] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> LoginDict:
"""Helper for register_device
@@ -793,40 +803,86 @@ class RegistrationHandler:
class and RegisterDeviceReplicationServlet.
"""
assert not self.hs.config.worker.worker_app
- valid_until_ms = None
+ now_ms = self.clock.time_msec()
+ access_token_expiry = None
if self.session_lifetime is not None:
if is_guest:
raise Exception(
"session_lifetime is not currently implemented for guest access"
)
- valid_until_ms = self.clock.time_msec() + self.session_lifetime
+ access_token_expiry = now_ms + self.session_lifetime
+
+ if self.nonrefreshable_access_token_lifetime is not None:
+ if access_token_expiry is not None:
+ # Don't allow the non-refreshable access token to outlive the
+ # session.
+ access_token_expiry = min(
+ now_ms + self.nonrefreshable_access_token_lifetime,
+ access_token_expiry,
+ )
+ else:
+ access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime
refresh_token = None
refresh_token_id = None
registered_device_id = await self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
+ user_id,
+ device_id,
+ initial_display_name,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
if is_guest:
- assert valid_until_ms is None
+ assert access_token_expiry is None
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
else:
if should_issue_refresh_token:
+ # A refreshable access token lifetime must be configured
+ # since we're told to issue a refresh token (the caller checks
+ # that this value is set before setting this flag).
+ assert self.refreshable_access_token_lifetime is not None
+
+ # Set the expiry time of the refreshable access token
+ access_token_expiry = now_ms + self.refreshable_access_token_lifetime
+
+ # Set the refresh token expiry time (if configured)
+ refresh_token_expiry = None
+ if self.refresh_token_lifetime is not None:
+ refresh_token_expiry = now_ms + self.refresh_token_lifetime
+
+ # Set an ultimate session expiry time (if configured)
+ ultimate_session_expiry_ts = None
+ if self.session_lifetime is not None:
+ ultimate_session_expiry_ts = now_ms + self.session_lifetime
+
+ # Also ensure that the issued tokens don't outlive the
+ # session.
+ # (It would be weird to configure a homeserver with a shorter
+ # session lifetime than token lifetime, but may as well handle
+ # it.)
+ access_token_expiry = min(
+ access_token_expiry, ultimate_session_expiry_ts
+ )
+ if refresh_token_expiry is not None:
+ refresh_token_expiry = min(
+ refresh_token_expiry, ultimate_session_expiry_ts
+ )
+
(
refresh_token,
refresh_token_id,
) = await self._auth_handler.create_refresh_token_for_user_id(
user_id,
device_id=registered_device_id,
- )
- valid_until_ms = (
- self.clock.time_msec() + self.refreshable_access_token_lifetime
+ expiry_ts=refresh_token_expiry,
+ ultimate_session_expiry_ts=ultimate_session_expiry_ts,
)
access_token = await self._auth_handler.create_access_token_for_user_id(
user_id,
device_id=registered_device_id,
- valid_until_ms=valid_until_ms,
+ valid_until_ms=access_token_expiry,
is_appservice_ghost=is_appservice_ghost,
refresh_token_id=refresh_token_id,
)
@@ -834,7 +890,7 @@ class RegistrationHandler:
return {
"device_id": registered_device_id,
"access_token": access_token,
- "valid_until_ms": valid_until_ms,
+ "valid_until_ms": access_token_expiry,
"refresh_token": refresh_token,
}
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 88053f9869..ead2198e14 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -46,6 +46,7 @@ from synapse.api.constants import (
from synapse.api.errors import (
AuthError,
Codes,
+ HttpResponseException,
LimitExceededError,
NotFoundError,
StoreError,
@@ -56,6 +57,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
+from synapse.federation.federation_client import InvalidResponseError
+from synapse.handlers.federation import get_domains_from_state
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
@@ -1220,6 +1223,147 @@ class RoomContextHandler:
return results
+class TimestampLookupHandler:
+ def __init__(self, hs: "HomeServer"):
+ self.server_name = hs.hostname
+ self.store = hs.get_datastore()
+ self.state_handler = hs.get_state_handler()
+ self.federation_client = hs.get_federation_client()
+
+ async def get_event_for_timestamp(
+ self,
+ requester: Requester,
+ room_id: str,
+ timestamp: int,
+ direction: str,
+ ) -> Tuple[str, int]:
+ """Find the closest event to the given timestamp in the given direction.
+ If we can't find an event locally or the event we have locally is next to a gap,
+ it will ask other federated homeservers for an event.
+
+ Args:
+ requester: The user making the request according to the access token
+ room_id: Room to fetch the event from
+ timestamp: The point in time (inclusive) we should navigate from in
+ the given direction to find the closest event.
+ direction: ["f"|"b"] to indicate whether we should navigate forward
+ or backward from the given timestamp to find the closest event.
+
+ Returns:
+ A tuple containing the `event_id` closest to the given timestamp in
+ the given direction and the `origin_server_ts`.
+
+ Raises:
+ SynapseError if unable to find any event locally in the given direction
+ """
+
+ local_event_id = await self.store.get_event_id_for_timestamp(
+ room_id, timestamp, direction
+ )
+ logger.debug(
+ "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s",
+ local_event_id,
+ timestamp,
+ )
+
+ # Check for gaps in the history where events could be hiding in between
+ # the timestamp given and the event we were able to find locally
+ is_event_next_to_backward_gap = False
+ is_event_next_to_forward_gap = False
+ if local_event_id:
+ local_event = await self.store.get_event(
+ local_event_id, allow_none=False, allow_rejected=False
+ )
+
+ if direction == "f":
+ # We only need to check for a backward gap if we're looking forwards
+ # to ensure there is nothing in between.
+ is_event_next_to_backward_gap = (
+ await self.store.is_event_next_to_backward_gap(local_event)
+ )
+ elif direction == "b":
+ # We only need to check for a forward gap if we're looking backwards
+ # to ensure there is nothing in between
+ is_event_next_to_forward_gap = (
+ await self.store.is_event_next_to_forward_gap(local_event)
+ )
+
+ # If we found a gap, we should probably ask another homeserver first
+ # about more history in between
+ if (
+ not local_event_id
+ or is_event_next_to_backward_gap
+ or is_event_next_to_forward_gap
+ ):
+ logger.debug(
+ "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first",
+ local_event_id,
+ timestamp,
+ )
+
+ # Find other homeservers from the given state in the room
+ curr_state = await self.state_handler.get_current_state(room_id)
+ curr_domains = get_domains_from_state(curr_state)
+ likely_domains = [
+ domain for domain, depth in curr_domains if domain != self.server_name
+ ]
+
+ # Loop through each homeserver candidate until we get a succesful response
+ for domain in likely_domains:
+ try:
+ remote_response = await self.federation_client.timestamp_to_event(
+ domain, room_id, timestamp, direction
+ )
+ logger.debug(
+ "get_event_for_timestamp: response from domain(%s)=%s",
+ domain,
+ remote_response,
+ )
+
+ # TODO: Do we want to persist this as an extremity?
+ # TODO: I think ideally, we would try to backfill from
+ # this event and run this whole
+ # `get_event_for_timestamp` function again to make sure
+ # they didn't give us an event from their gappy history.
+ remote_event_id = remote_response.event_id
+ origin_server_ts = remote_response.origin_server_ts
+
+ # Only return the remote event if it's closer than the local event
+ if not local_event or (
+ abs(origin_server_ts - timestamp)
+ < abs(local_event.origin_server_ts - timestamp)
+ ):
+ return remote_event_id, origin_server_ts
+ except (HttpResponseException, InvalidResponseError) as ex:
+ # Let's not put a high priority on some other homeserver
+ # failing to respond or giving a random response
+ logger.debug(
+ "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+ domain,
+ type(ex).__name__,
+ ex,
+ ex.args,
+ )
+ except Exception as ex:
+ # But we do want to see some exceptions in our code
+ logger.warning(
+ "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+ domain,
+ type(ex).__name__,
+ ex,
+ ex.args,
+ )
+
+ if not local_event_id:
+ raise SynapseError(
+ 404,
+ "Unable to find event from %s in direction %s" % (timestamp, direction),
+ errcode=Codes.NOT_FOUND,
+ )
+
+ return local_event_id, local_event.origin_server_ts
+
+
class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -1391,20 +1535,13 @@ class RoomShutdownHandler:
await self.store.block_room(room_id, requester_user_id)
if not await self.store.get_room(room_id):
- if block:
- # We allow you to block an unknown room.
- return {
- "kicked_users": [],
- "failed_to_kick_users": [],
- "local_aliases": [],
- "new_room_id": None,
- }
- else:
- # But if you don't want to preventatively block another room,
- # this function can't do anything useful.
- raise NotFoundError(
- "Cannot shut down room: unknown room id %s" % (room_id,)
- )
+ # if we don't know about the room, there is nothing left to do.
+ return {
+ "kicked_users": [],
+ "failed_to_kick_users": [],
+ "local_aliases": [],
+ "new_room_id": None,
+ }
if new_room_user_id is not None:
if not self.hs.is_mine_id(new_room_user_id):
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 8181cc0b52..b2cfe537df 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -36,8 +36,9 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
-from synapse.types import JsonDict
+from synapse.types import JsonDict, Requester
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -93,6 +94,9 @@ class RoomSummaryHandler:
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
+ self._ratelimiter = Ratelimiter(
+ store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
+ )
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
@@ -249,7 +253,7 @@ class RoomSummaryHandler:
async def get_room_hierarchy(
self,
- requester: str,
+ requester: Requester,
requested_room_id: str,
suggested_only: bool = False,
max_depth: Optional[int] = None,
@@ -276,6 +280,8 @@ class RoomSummaryHandler:
Returns:
The JSON hierarchy dictionary.
"""
+ await self._ratelimiter.ratelimit(requester)
+
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
#
@@ -283,7 +289,7 @@ class RoomSummaryHandler:
# to process multiple requests for the same page will result in errors.
return await self._pagination_response_cache.wrap(
(
- requester,
+ requester.user.to_string(),
requested_room_id,
suggested_only,
max_depth,
@@ -291,7 +297,7 @@ class RoomSummaryHandler:
from_token,
),
self._get_room_hierarchy,
- requester,
+ requester.user.to_string(),
requested_room_id,
suggested_only,
max_depth,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 49fde01cf0..65c27bc64a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -365,6 +365,7 @@ class SsoHandler:
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
extra_login_attributes: Optional[JsonDict] = None,
+ auth_provider_session_id: Optional[str] = None,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -415,6 +416,8 @@ class SsoHandler:
extra_login_attributes: An optional dictionary of extra
attributes to be provided to the client in the login response.
+ auth_provider_session_id: An optional session ID from the IdP.
+
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: if the mapping provider needs to redirect the user
@@ -490,6 +493,7 @@ class SsoHandler:
client_redirect_url,
extra_login_attributes,
new_user=new_user,
+ auth_provider_session_id=auth_provider_session_id,
)
async def _call_attribute_mapper(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 891435c14d..f3039c3c3f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -334,6 +334,19 @@ class SyncHandler:
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> SyncResult:
+ """The start of the machinery that produces a /sync response.
+
+ See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
+
+ This method does high-level bookkeeping:
+ - tracking the kind of sync in the logging context
+ - deleting any to_device messages whose delivery has been acknowledged.
+ - deciding if we should dispatch an instant or delayed response
+ - marking the sync as being lazily loaded, if appropriate
+
+ Computing the body of the response begins in the next method,
+ `current_sync_for_user`.
+ """
if since_token is None:
sync_type = "initial_sync"
elif full_state:
@@ -363,7 +376,7 @@ class SyncHandler:
sync_config, since_token, full_state=full_state
)
else:
-
+ # Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
) -> SyncResult:
@@ -402,7 +415,12 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
- """Get the sync for client needed to match what the server has now."""
+ """Generates the response body of a sync result, represented as a SyncResult.
+
+ This is a wrapper around `generate_sync_result` which starts an open tracing
+ span to track the sync. See `generate_sync_result` for the next part of your
+ indoctrination.
+ """
with start_active_span("current_sync_for_user"):
log_kv({"since_token": since_token})
sync_result = await self.generate_sync_result(
@@ -560,7 +578,7 @@ class SyncHandler:
# that have happened since `since_key` up to `end_key`, so we
# can just use `get_room_events_stream_for_room`.
# Otherwise, we want to return the last N events in the room
- # in toplogical ordering.
+ # in topological ordering.
if since_key:
events, end_key = await self.store.get_room_events_stream_for_room(
room_id,
@@ -1042,7 +1060,18 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
- """Generates a sync result."""
+ """Generates the response body of a sync result.
+
+ This is represented by a `SyncResult` struct, which is built from small pieces
+ using a `SyncResultBuilder`. See also
+ https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
+ the `sync_result_builder` is passed as a mutable ("inout") parameter to various
+ helper functions. These retrieve and process the data which forms the sync body,
+ often writing to the `sync_result_builder` to store their output.
+
+ At the end, we transfer data from the `sync_result_builder` to a new `SyncResult`
+ instance to signify that the sync calculation is complete.
+ """
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
@@ -1344,14 +1373,22 @@ class SyncHandler:
async def _generate_sync_entry_for_account_data(
self, sync_result_builder: "SyncResultBuilder"
) -> Dict[str, Dict[str, JsonDict]]:
- """Generates the account data portion of the sync response. Populates
- `sync_result_builder` with the result.
+ """Generates the account data portion of the sync response.
+
+ Account data (called "Client Config" in the spec) can be set either globally
+ or for a specific room. Account data consists of a list of events which
+ accumulate state, much like a room.
+
+ This function retrieves global and per-room account data. The former is written
+ to the given `sync_result_builder`. The latter is returned directly, to be
+ later written to the `sync_result_builder` on a room-by-room basis.
Args:
sync_result_builder
Returns:
- A dictionary containing the per room account data.
+ A dictionary whose keys (room ids) map to the per room account data for that
+ room.
"""
sync_config = sync_result_builder.sync_config
user_id = sync_result_builder.sync_config.user.to_string()
@@ -1359,7 +1396,7 @@ class SyncHandler:
if since_token and not sync_result_builder.full_state:
(
- account_data,
+ global_account_data,
account_data_by_room,
) = await self.store.get_updated_account_data_for_user(
user_id, since_token.account_data_key
@@ -1370,23 +1407,23 @@ class SyncHandler:
)
if push_rules_changed:
- account_data["m.push_rules"] = await self.push_rules_for_user(
+ global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
else:
(
- account_data,
+ global_account_data,
account_data_by_room,
) = await self.store.get_account_data_for_user(sync_config.user.to_string())
- account_data["m.push_rules"] = await self.push_rules_for_user(
+ global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
account_data_for_user = await sync_config.filter_collection.filter_account_data(
[
{"type": account_data_type, "content": content}
- for account_data_type, content in account_data.items()
+ for account_data_type, content in global_account_data.items()
]
)
@@ -1460,18 +1497,31 @@ class SyncHandler:
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
+ In the response that reaches the client, rooms are divided into four categories:
+ `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of
+ room ids returned by this function.
+
Args:
sync_result_builder
account_data_by_room: Dictionary of per room account data
Returns:
- Returns a 4-tuple of
- `(newly_joined_rooms, newly_joined_or_invited_users,
- newly_left_rooms, newly_left_users)`
+ Returns a 4-tuple describing rooms the user has joined or left, and users who've
+ joined or left rooms any rooms the user is in. This gets used later in
+ `_generate_sync_entry_for_device_list`.
+
+ Its entries are:
+ - newly_joined_rooms
+ - newly_joined_or_invited_or_knocked_users
+ - newly_left_rooms
+ - newly_left_users
"""
+ since_token = sync_result_builder.since_token
+
+ # 1. Start by fetching all ephemeral events in rooms we've joined (if required).
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
- sync_result_builder.since_token is None
+ since_token is None
and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
@@ -1485,9 +1535,8 @@ class SyncHandler:
)
sync_result_builder.now_token = now_token
- # We check up front if anything has changed, if it hasn't then there is
+ # 2. We check up front if anything has changed, if it hasn't then there is
# no point in going further.
- since_token = sync_result_builder.since_token
if not sync_result_builder.full_state:
if since_token and not ephemeral_by_room and not account_data_by_room:
have_changed = await self._have_rooms_changed(sync_result_builder)
@@ -1500,20 +1549,8 @@ class SyncHandler:
logger.debug("no-oping sync")
return set(), set(), set(), set()
- ignored_account_data = (
- await self.store.get_global_account_data_by_type_for_user(
- AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
- )
- )
-
- # If there is ignored users account data and it matches the proper type,
- # then use it.
- ignored_users: FrozenSet[str] = frozenset()
- if ignored_account_data:
- ignored_users_data = ignored_account_data.get("ignored_users", {})
- if isinstance(ignored_users_data, dict):
- ignored_users = frozenset(ignored_users_data.keys())
-
+ # 3. Work out which rooms need reporting in the sync response.
+ ignored_users = await self._get_ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@@ -1523,7 +1560,6 @@ class SyncHandler:
)
else:
room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
-
tags_by_room = await self.store.get_tags_for_user(user_id)
log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1534,6 +1570,8 @@ class SyncHandler:
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
+ # 4. We need to apply further processing to `room_entries` (rooms considered
+ # joined or archived).
async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
@@ -1552,31 +1590,13 @@ class SyncHandler:
sync_result_builder.invited.extend(invited)
sync_result_builder.knocked.extend(knocked)
- # Now we want to get any newly joined, invited or knocking users
- newly_joined_or_invited_or_knocked_users = set()
- newly_left_users = set()
- if since_token:
- for joined_sync in sync_result_builder.joined:
- it = itertools.chain(
- joined_sync.timeline.events, joined_sync.state.values()
- )
- for event in it:
- if event.type == EventTypes.Member:
- if (
- event.membership == Membership.JOIN
- or event.membership == Membership.INVITE
- or event.membership == Membership.KNOCK
- ):
- newly_joined_or_invited_or_knocked_users.add(
- event.state_key
- )
- else:
- prev_content = event.unsigned.get("prev_content", {})
- prev_membership = prev_content.get("membership", None)
- if prev_membership == Membership.JOIN:
- newly_left_users.add(event.state_key)
-
- newly_left_users -= newly_joined_or_invited_or_knocked_users
+ # 5. Work out which users have joined or left rooms we're in. We use this
+ # to build the device_list part of the sync response in
+ # `_generate_sync_entry_for_device_list`.
+ (
+ newly_joined_or_invited_or_knocked_users,
+ newly_left_users,
+ ) = sync_result_builder.calculate_user_changes()
return (
set(newly_joined_rooms),
@@ -1585,11 +1605,36 @@ class SyncHandler:
newly_left_users,
)
+ async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
+ """Retrieve the users ignored by the given user from their global account_data.
+
+ Returns an empty set if
+ - there is no global account_data entry for ignored_users
+ - there is such an entry, but it's not a JSON object.
+ """
+ # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
+ ignored_account_data = (
+ await self.store.get_global_account_data_by_type_for_user(
+ AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+ )
+ )
+
+ # If there is ignored users account data and it matches the proper type,
+ # then use it.
+ ignored_users: FrozenSet[str] = frozenset()
+ if ignored_account_data:
+ ignored_users_data = ignored_account_data.get("ignored_users", {})
+ if isinstance(ignored_users_data, dict):
+ ignored_users = frozenset(ignored_users_data.keys())
+ return ignored_users
+
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
"""Returns whether there may be any new events that should be sent down
the sync. Returns True if there are.
+
+ Does not modify the `sync_result_builder`.
"""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -1597,12 +1642,13 @@ class SyncHandler:
assert since_token
- # Get a list of membership change events that have happened.
- rooms_changed = await self.store.get_membership_changes_for_user(
+ # Get a list of membership change events that have happened to the user
+ # requesting the sync.
+ membership_changes = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
- if rooms_changed:
+ if membership_changes:
return True
stream_id = since_token.room_key.stream
@@ -1614,7 +1660,25 @@ class SyncHandler:
async def _get_rooms_changed(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
- """Gets the the changes that have happened since the last sync."""
+ """Determine the changes in rooms to report to the user.
+
+ Ideally, we want to report all events whose stream ordering `s` lies in the
+ range `since_token < s <= now_token`, where the two tokens are read from the
+ sync_result_builder.
+
+ If there are too many events in that range to report, things get complicated.
+ In this situation we return a truncated list of the most recent events, and
+ indicate in the response that there is a "gap" of omitted events. Additionally:
+
+ - we include a "state_delta", to describe the changes in state over the gap,
+ - we include all membership events applying to the user making the request,
+ even those in the gap.
+
+ See the spec for the rationale:
+ https://spec.matrix.org/v1.1/client-server-api/#syncing
+
+ The sync_result_builder is not modified by this function.
+ """
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
@@ -1622,21 +1686,36 @@ class SyncHandler:
assert since_token
- # Get a list of membership change events that have happened.
- rooms_changed = await self.store.get_membership_changes_for_user(
+ # The spec
+ # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
+ # notes that membership events need special consideration:
+ #
+ # > When a sync is limited, the server MUST return membership events for events
+ # > in the gap (between since and the start of the returned timeline), regardless
+ # > as to whether or not they are redundant.
+ #
+ # We fetch such events here, but we only seem to use them for categorising rooms
+ # as newly joined, newly left, invited or knocked.
+ # TODO: we've already called this function and ran this query in
+ # _have_rooms_changed. We could keep the results in memory to avoid a
+ # second query, at the cost of more complicated source code.
+ membership_change_events = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
- for event in rooms_changed:
+ for event in membership_change_events:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
- newly_joined_rooms = []
- newly_left_rooms = []
- room_entries = []
- invited = []
- knocked = []
+ newly_joined_rooms: List[str] = []
+ newly_left_rooms: List[str] = []
+ room_entries: List[RoomSyncResultBuilder] = []
+ invited: List[InvitedSyncResult] = []
+ knocked: List[KnockedSyncResult] = []
for room_id, events in mem_change_events_by_room_id.items():
+ # The body of this loop will add this room to at least one of the five lists
+ # above. Things get messy if you've e.g. joined, left, joined then left the
+ # room all in the same sync period.
logger.debug(
"Membership changes in %s: [%s]",
room_id,
@@ -1691,6 +1770,7 @@ class SyncHandler:
if not non_joins:
continue
+ last_non_join = non_joins[-1]
# Check if we have left the room. This can either be because we were
# joined before *or* that we since joined and then left.
@@ -1712,18 +1792,18 @@ class SyncHandler:
newly_left_rooms.append(room_id)
# Only bother if we're still currently invited
- should_invite = non_joins[-1].membership == Membership.INVITE
+ should_invite = last_non_join.membership == Membership.INVITE
if should_invite:
- if event.sender not in ignored_users:
- invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+ if last_non_join.sender not in ignored_users:
+ invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join)
if invite_room_sync:
invited.append(invite_room_sync)
# Only bother if our latest membership in the room is knock (and we haven't
# been accepted/rejected in the meantime).
- should_knock = non_joins[-1].membership == Membership.KNOCK
+ should_knock = last_non_join.membership == Membership.KNOCK
if should_knock:
- knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
+ knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join)
if knock_room_sync:
knocked.append(knock_room_sync)
@@ -1781,7 +1861,9 @@ class SyncHandler:
timeline_limit = sync_config.filter_collection.timeline_limit()
- # Get all events for rooms we're currently joined to.
+ # Get all events since the `from_key` in rooms we're currently joined to.
+ # If there are too many, we get the most recent events only. This leaves
+ # a "gap" in the timeline, as described by the spec for /sync.
room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key,
@@ -1842,6 +1924,10 @@ class SyncHandler:
) -> _RoomChanges:
"""Returns entries for all rooms for the user.
+ Like `_get_rooms_changed`, but assumes the `since_token` is `None`.
+
+ This function does not modify the sync_result_builder.
+
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
@@ -1853,16 +1939,9 @@ class SyncHandler:
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
- membership_list = (
- Membership.INVITE,
- Membership.KNOCK,
- Membership.JOIN,
- Membership.LEAVE,
- Membership.BAN,
- )
-
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
- user_id=user_id, membership_list=membership_list
+ user_id=user_id,
+ membership_list=Membership.LIST,
)
room_entries = []
@@ -2212,8 +2291,7 @@ def _calculate_state(
# to only include membership events for the senders in the timeline.
# In practice, we can do this by removing them from the p_ids list,
# which is the list of relevant state we know we have already sent to the client.
- # see https://github.com/matrix-org/synapse/pull/2970
- # /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
+ # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
if lazy_load_members:
p_ids.difference_update(
@@ -2262,6 +2340,39 @@ class SyncResultBuilder:
groups: Optional[GroupsSyncResult] = None
to_device: List[JsonDict] = attr.Factory(list)
+ def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
+ """Work out which other users have joined or left rooms we are joined to.
+
+ This data only is only useful for an incremental sync.
+
+ The SyncResultBuilder is not modified by this function.
+ """
+ newly_joined_or_invited_or_knocked_users = set()
+ newly_left_users = set()
+ if self.since_token:
+ for joined_sync in self.joined:
+ it = itertools.chain(
+ joined_sync.timeline.events, joined_sync.state.values()
+ )
+ for event in it:
+ if event.type == EventTypes.Member:
+ if (
+ event.membership == Membership.JOIN
+ or event.membership == Membership.INVITE
+ or event.membership == Membership.KNOCK
+ ):
+ newly_joined_or_invited_or_knocked_users.add(
+ event.state_key
+ )
+ else:
+ prev_content = event.unsigned.get("prev_content", {})
+ prev_membership = prev_content.get("membership", None)
+ if prev_membership == Membership.JOIN:
+ newly_left_users.add(event.state_key)
+
+ newly_left_users -= newly_joined_or_invited_or_knocked_users
+ return newly_joined_or_invited_or_knocked_users, newly_left_users
+
@attr.s(slots=True, auto_attribs=True)
class RoomSyncResultBuilder:
|