diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 61607cf2ba..4b66a9862f 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,7 +18,6 @@ import time
import unicodedata
import urllib.parse
from binascii import crc32
-from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@@ -39,7 +38,6 @@ import attr
import bcrypt
import pymacaroons
import unpaddedbase64
-from pymacaroons.exceptions import MacaroonVerificationFailedException
from twisted.web.server import Request
@@ -183,11 +181,8 @@ class LoginTokenAttributes:
user_id = attr.ib(type=str)
+ # the SSO Identity Provider that the user authenticated with, to get this token
auth_provider_id = attr.ib(type=str)
- """The SSO Identity Provider that the user authenticated with, to get this token."""
-
- auth_provider_session_id = attr.ib(type=Optional[str])
- """The session ID advertised by the SSO Identity Provider."""
class AuthHandler:
@@ -761,109 +756,53 @@ class AuthHandler:
async def refresh_token(
self,
refresh_token: str,
- access_token_valid_until_ms: Optional[int],
- refresh_token_valid_until_ms: Optional[int],
- ) -> Tuple[str, str, Optional[int]]:
+ valid_until_ms: Optional[int],
+ ) -> Tuple[str, str]:
"""
Consumes a refresh token and generate both a new access token and a new refresh token from it.
The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
- The lifetime of both the access token and refresh token will be capped so that they
- do not exceed the session's ultimate expiry time, if applicable.
-
Args:
refresh_token: The token to consume.
- access_token_valid_until_ms: The expiration timestamp of the new access token.
- None if the access token does not expire.
- refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
- None if the refresh token does not expire.
+ valid_until_ms: The expiration timestamp of the new access token.
+
Returns:
- A tuple containing:
- - the new access token
- - the new refresh token
- - the actual expiry time of the access token, which may be earlier than
- `access_token_valid_until_ms`.
+ A tuple containing the new access token and refresh token
"""
# Verify the token signature first before looking up the token
if not self._verify_refresh_token(refresh_token):
- raise SynapseError(
- HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
- )
+ raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
existing_token = await self.store.lookup_refresh_token(refresh_token)
if existing_token is None:
- raise SynapseError(
- HTTPStatus.UNAUTHORIZED,
- "refresh token does not exist",
- Codes.UNKNOWN_TOKEN,
- )
+ raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
if (
existing_token.has_next_access_token_been_used
or existing_token.has_next_refresh_token_been_refreshed
):
raise SynapseError(
- HTTPStatus.FORBIDDEN,
- "refresh token isn't valid anymore",
- Codes.FORBIDDEN,
+ 403, "refresh token isn't valid anymore", Codes.FORBIDDEN
)
- now_ms = self._clock.time_msec()
-
- if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
-
- raise SynapseError(
- HTTPStatus.FORBIDDEN,
- "The supplied refresh token has expired",
- Codes.FORBIDDEN,
- )
-
- if existing_token.ultimate_session_expiry_ts is not None:
- # This session has a bounded lifetime, even across refreshes.
-
- if access_token_valid_until_ms is not None:
- access_token_valid_until_ms = min(
- access_token_valid_until_ms,
- existing_token.ultimate_session_expiry_ts,
- )
- else:
- access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
-
- if refresh_token_valid_until_ms is not None:
- refresh_token_valid_until_ms = min(
- refresh_token_valid_until_ms,
- existing_token.ultimate_session_expiry_ts,
- )
- else:
- refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
- if existing_token.ultimate_session_expiry_ts < now_ms:
- raise SynapseError(
- HTTPStatus.FORBIDDEN,
- "The session has expired and can no longer be refreshed",
- Codes.FORBIDDEN,
- )
-
(
new_refresh_token,
new_refresh_token_id,
) = await self.create_refresh_token_for_user_id(
- user_id=existing_token.user_id,
- device_id=existing_token.device_id,
- expiry_ts=refresh_token_valid_until_ms,
- ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
+ user_id=existing_token.user_id, device_id=existing_token.device_id
)
access_token = await self.create_access_token_for_user_id(
user_id=existing_token.user_id,
device_id=existing_token.device_id,
- valid_until_ms=access_token_valid_until_ms,
+ valid_until_ms=valid_until_ms,
refresh_token_id=new_refresh_token_id,
)
await self.store.replace_refresh_token(
existing_token.token_id, new_refresh_token_id
)
- return access_token, new_refresh_token, access_token_valid_until_ms
+ return access_token, new_refresh_token
def _verify_refresh_token(self, token: str) -> bool:
"""
@@ -897,8 +836,6 @@ class AuthHandler:
self,
user_id: str,
device_id: str,
- expiry_ts: Optional[int],
- ultimate_session_expiry_ts: Optional[int],
) -> Tuple[str, int]:
"""
Creates a new refresh token for the user with the given user ID.
@@ -906,13 +843,6 @@ class AuthHandler:
Args:
user_id: canonical user ID
device_id: the device ID to associate with the token.
- expiry_ts (milliseconds since the epoch): Time after which the
- refresh token cannot be used.
- If None, the refresh token never expires until it has been used.
- ultimate_session_expiry_ts (milliseconds since the epoch):
- Time at which the session will end and can not be extended any
- further.
- If None, the session can be refreshed indefinitely.
Returns:
The newly created refresh token and its ID in the database
@@ -922,8 +852,6 @@ class AuthHandler:
user_id=user_id,
token=refresh_token,
device_id=device_id,
- expiry_ts=expiry_ts,
- ultimate_session_expiry_ts=ultimate_session_expiry_ts,
)
return refresh_token, refresh_token_id
@@ -1654,7 +1582,6 @@ class AuthHandler:
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
- auth_provider_session_id: Optional[str] = None,
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
@@ -1670,7 +1597,6 @@ class AuthHandler:
during successful login. Must be JSON serializable.
new_user: True if we should use wording appropriate to a user who has just
registered.
- auth_provider_session_id: The session ID from the SSO IdP received during login.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
@@ -1691,7 +1617,6 @@ class AuthHandler:
extra_attributes,
new_user=new_user,
user_profile_data=profile,
- auth_provider_session_id=auth_provider_session_id,
)
def _complete_sso_login(
@@ -1703,7 +1628,6 @@ class AuthHandler:
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
- auth_provider_session_id: Optional[str] = None,
) -> None:
"""
The synchronous portion of complete_sso_login.
@@ -1725,9 +1649,7 @@ class AuthHandler:
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id,
- auth_provider_id=auth_provider_id,
- auth_provider_session_id=auth_provider_session_id,
+ registered_user_id, auth_provider_id=auth_provider_id
)
# Append the login token to the original redirect URL (i.e. with its query
@@ -1832,7 +1754,6 @@ class MacaroonGenerator:
self,
user_id: str,
auth_provider_id: str,
- auth_provider_session_id: Optional[str] = None,
duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
@@ -1841,10 +1762,6 @@ class MacaroonGenerator:
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
- if auth_provider_session_id is not None:
- macaroon.add_first_party_caveat(
- "auth_provider_session_id = %s" % (auth_provider_session_id,)
- )
return macaroon.serialize()
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@@ -1866,28 +1783,15 @@ class MacaroonGenerator:
user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
- auth_provider_session_id: Optional[str] = None
- try:
- auth_provider_session_id = get_value_from_macaroon(
- macaroon, "auth_provider_session_id"
- )
- except MacaroonVerificationFailedException:
- pass
-
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
- v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
satisfy_expiry(v, self.hs.get_clock().time_msec)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
- return LoginTokenAttributes(
- user_id=user_id,
- auth_provider_id=auth_provider_id,
- auth_provider_session_id=auth_provider_session_id,
- )
+ return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 82ee11e921..68b446eb66 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -301,8 +301,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: str,
device_id: Optional[str],
initial_device_display_name: Optional[str] = None,
- auth_provider_id: Optional[str] = None,
- auth_provider_session_id: Optional[str] = None,
) -> str:
"""
If the given device has not been registered, register it with the
@@ -314,8 +312,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: @user:id
device_id: device id supplied by client
initial_device_display_name: device display name from client
- auth_provider_id: The SSO IdP the user used, if any.
- auth_provider_session_id: The session ID (sid) got from the SSO IdP.
Returns:
device id (generated if none was supplied)
"""
@@ -327,8 +323,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
- auth_provider_id=auth_provider_id,
- auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [device_id])
@@ -343,8 +337,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=new_device_id,
initial_device_display_name=initial_device_display_name,
- auth_provider_id=auth_provider_id,
- auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [new_device_id])
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 32b0254c5f..b4ff935546 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -122,8 +122,9 @@ class EventStreamHandler:
events,
time_now,
as_client_event=as_client_event,
- # Don't bundle aggregations as this is a deprecated API.
- bundle_aggregations=False,
+ # We don't bundle "live" events, as otherwise clients
+ # will end up double counting annotations.
+ bundle_relations=False,
)
chunk = {
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1ea837d082..3112cc88b1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -68,37 +68,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
- """Get joined domains from state
-
- Args:
- state: State map from type/state key to event.
-
- Returns:
- Returns a list of servers with the lowest depth of their joins.
- Sorted by lowest depth first.
- """
- joined_users = [
- (state_key, int(event.depth))
- for (e_type, state_key), event in state.items()
- if e_type == EventTypes.Member and event.membership == Membership.JOIN
- ]
-
- joined_domains: Dict[str, int] = {}
- for u, d in joined_users:
- try:
- dom = get_domain_from_id(u)
- old_d = joined_domains.get(dom)
- if old_d:
- joined_domains[dom] = min(d, old_d)
- else:
- joined_domains[dom] = d
- except Exception:
- pass
-
- return sorted(joined_domains.items(), key=lambda d: d[1])
-
-
class FederationHandler:
"""Handles general incoming federation requests
@@ -299,6 +268,36 @@ class FederationHandler:
curr_state = await self.state_handler.get_current_state(room_id)
+ def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
+ """Get joined domains from state
+
+ Args:
+ state: State map from type/state key to event.
+
+ Returns:
+ Returns a list of servers with the lowest depth of their joins.
+ Sorted by lowest depth first.
+ """
+ joined_users = [
+ (state_key, int(event.depth))
+ for (e_type, state_key), event in state.items()
+ if e_type == EventTypes.Member and event.membership == Membership.JOIN
+ ]
+
+ joined_domains: Dict[str, int] = {}
+ for u, d in joined_users:
+ try:
+ dom = get_domain_from_id(u)
+ old_d = joined_domains.get(dom)
+ if old_d:
+ joined_domains[dom] = min(d, old_d)
+ else:
+ joined_domains[dom] = d
+ except Exception:
+ pass
+
+ return sorted(joined_domains.items(), key=lambda d: d[1])
+
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 9cd21e7f2b..d4e4556155 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -165,11 +165,7 @@ class InitialSyncHandler:
invite_event = await self.store.get_event(event.event_id)
d["invite"] = await self._event_serializer.serialize_event(
- invite_event,
- time_now,
- # Don't bundle aggregations as this is a deprecated API.
- bundle_aggregations=False,
- as_client_event=as_client_event,
+ invite_event, time_now, as_client_event
)
rooms_ret.append(d)
@@ -220,11 +216,7 @@ class InitialSyncHandler:
d["messages"] = {
"chunk": (
await self._event_serializer.serialize_events(
- messages,
- time_now=time_now,
- # Don't bundle aggregations as this is a deprecated API.
- bundle_aggregations=False,
- as_client_event=as_client_event,
+ messages, time_now=time_now, as_client_event=as_client_event
)
),
"start": await start_token.to_string(self.store),
@@ -234,8 +226,6 @@ class InitialSyncHandler:
d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
- # Don't bundle aggregations as this is a deprecated API.
- bundle_aggregations=False,
as_client_event=as_client_event,
)
@@ -376,18 +366,14 @@ class InitialSyncHandler:
"room_id": room_id,
"messages": {
"chunk": (
- # Don't bundle aggregations as this is a deprecated API.
- await self._event_serializer.serialize_events(
- messages, time_now, bundle_aggregations=False
- )
+ await self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
- # Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
- room_state.values(), time_now, bundle_aggregations=False
+ room_state.values(), time_now
)
),
"presence": [],
@@ -406,9 +392,8 @@ class InitialSyncHandler:
# TODO: These concurrently
time_now = self.clock.time_msec()
- # Don't bundle aggregations as this is a deprecated API.
state = await self._event_serializer.serialize_events(
- current_state.values(), time_now, bundle_aggregations=False
+ current_state.values(), time_now
)
now_token = self.hs.get_event_sources().get_current_token()
@@ -482,10 +467,7 @@ class InitialSyncHandler:
"room_id": room_id,
"messages": {
"chunk": (
- # Don't bundle aggregations as this is a deprecated API.
- await self._event_serializer.serialize_events(
- messages, time_now, bundle_aggregations=False
- )
+ await self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 87f671708c..95b4fad3c6 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -247,7 +247,13 @@ class MessageHandler:
room_state = room_state_events[membership_event_id]
now = self.clock.time_msec()
- events = await self._event_serializer.serialize_events(room_state.values(), now)
+ events = await self._event_serializer.serialize_events(
+ room_state.values(),
+ now,
+ # We don't bother bundling aggregations in when asked for state
+ # events, as clients won't use them.
+ bundle_relations=False,
+ )
return events
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index deb3539751..3665d91513 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -23,7 +23,7 @@ from authlib.common.security import generate_token
from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
-from authlib.oidc.core import CodeIDToken, UserInfo
+from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
@@ -117,8 +117,7 @@ class OidcHandler:
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
- if not p._uses_userinfo:
- await p.load_jwks()
+ await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
@@ -499,6 +498,10 @@ class OidcProvider:
return await self._jwks.get()
async def _load_jwks(self) -> JWKS:
+ if self._uses_userinfo:
+ # We're not using jwt signing, return an empty jwk set
+ return {"keys": []}
+
metadata = await self.load_metadata()
# Load the JWKS using the `jwks_uri` metadata.
@@ -660,7 +663,7 @@ class OidcProvider:
return UserInfo(resp)
- async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
+ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
"""Return an instance of UserInfo from token's ``id_token``.
Args:
@@ -670,7 +673,7 @@ class OidcProvider:
request. This value should match the one inside the token.
Returns:
- The decoded claims in the ID token.
+ An object representing the user.
"""
metadata = await self.load_metadata()
claims_params = {
@@ -681,6 +684,9 @@ class OidcProvider:
# If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
+ claims_cls = CodeIDToken
+ else:
+ claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
@@ -697,7 +703,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
- claims_cls=CodeIDToken,
+ claims_cls=claims_cls,
claims_options=claim_options,
claims_params=claims_params,
)
@@ -707,7 +713,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
- claims_cls=CodeIDToken,
+ claims_cls=claims_cls,
claims_options=claim_options,
claims_params=claims_params,
)
@@ -715,8 +721,7 @@ class OidcProvider:
logger.debug("Decoded id_token JWT %r; validating", claims)
claims.validate(leeway=120) # allows 2 min of clock skew
-
- return claims
+ return UserInfo(claims)
async def handle_redirect_request(
self,
@@ -832,22 +837,8 @@ class OidcProvider:
logger.debug("Successfully obtained OAuth2 token data: %r", token)
- # If there is an id_token, it should be validated, regardless of the
- # userinfo endpoint is used or not.
- if token.get("id_token") is not None:
- try:
- id_token = await self._parse_id_token(token, nonce=session_data.nonce)
- sid = id_token.get("sid")
- except Exception as e:
- logger.exception("Invalid id_token")
- self._sso_handler.render_error(request, "invalid_token", str(e))
- return
- else:
- id_token = None
- sid = None
-
- # Now that we have a token, get the userinfo either from the `id_token`
- # claims or by fetching the `userinfo_endpoint`.
+ # Now that we have a token, get the userinfo, either by decoding the
+ # `id_token` or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
try:
userinfo = await self._fetch_userinfo(token)
@@ -855,14 +846,13 @@ class OidcProvider:
logger.exception("Could not fetch userinfo")
self._sso_handler.render_error(request, "fetch_error", str(e))
return
- elif id_token is not None:
- userinfo = UserInfo(id_token)
else:
- logger.error("Missing id_token in token response")
- self._sso_handler.render_error(
- request, "invalid_token", "Missing id_token in token response"
- )
- return
+ try:
+ userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
+ except Exception as e:
+ logger.exception("Invalid id_token")
+ self._sso_handler.render_error(request, "invalid_token", str(e))
+ return
# first check if we're doing a UIA
if session_data.ui_auth_session_id:
@@ -894,7 +884,7 @@ class OidcProvider:
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
- userinfo, token, request, session_data.client_redirect_url, sid
+ userinfo, token, request, session_data.client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
@@ -906,7 +896,6 @@ class OidcProvider:
token: Token,
request: SynapseRequest,
client_redirect_url: str,
- sid: Optional[str],
) -> None:
"""Given a UserInfo response, complete the login flow
@@ -1019,7 +1008,6 @@ class OidcProvider:
oidc_response_to_user_attributes,
grandfather_existing_users,
extra_attributes,
- auth_provider_session_id=sid,
)
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 4f42438053..cd64142735 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -406,6 +406,9 @@ class PaginationHandler:
force: set true to skip checking for joined users.
"""
with await self.pagination_lock.write(room_id):
+ # check we know about the room
+ await self.store.get_room_version_id(room_id)
+
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 454d06c973..3df872c578 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -421,7 +421,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._on_shutdown,
)
- async def _on_shutdown(self) -> None:
+ def _on_shutdown(self) -> None:
if self._presence_enabled:
self.hs.get_tcp_replication().send_command(
ClearUserSyncsCommand(self.instance_id)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index f08a516a75..448a36108e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -1,5 +1,4 @@
# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -117,13 +116,9 @@ class RegistrationHandler:
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.registration.session_lifetime
- self.nonrefreshable_access_token_lifetime = (
- hs.config.registration.nonrefreshable_access_token_lifetime
- )
self.refreshable_access_token_lifetime = (
hs.config.registration.refreshable_access_token_lifetime
)
- self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
init_counters_for_auth_provider("")
@@ -746,7 +741,6 @@ class RegistrationHandler:
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
- auth_provider_session_id: Optional[str] = None,
) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token.
@@ -757,9 +751,9 @@ class RegistrationHandler:
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
- auth_provider_id: The SSO IdP the user used, if any.
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
should_issue_refresh_token: Whether it should also issue a refresh token
- auth_provider_session_id: The session ID received during login from the SSO IdP.
Returns:
Tuple of device ID, access token, access token expiration time and refresh token
"""
@@ -770,8 +764,6 @@ class RegistrationHandler:
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
- auth_provider_id=auth_provider_id,
- auth_provider_session_id=auth_provider_session_id,
)
login_counter.labels(
@@ -794,8 +786,6 @@ class RegistrationHandler:
is_guest: bool = False,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
- auth_provider_id: Optional[str] = None,
- auth_provider_session_id: Optional[str] = None,
) -> LoginDict:
"""Helper for register_device
@@ -803,86 +793,40 @@ class RegistrationHandler:
class and RegisterDeviceReplicationServlet.
"""
assert not self.hs.config.worker.worker_app
- now_ms = self.clock.time_msec()
- access_token_expiry = None
+ valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
raise Exception(
"session_lifetime is not currently implemented for guest access"
)
- access_token_expiry = now_ms + self.session_lifetime
-
- if self.nonrefreshable_access_token_lifetime is not None:
- if access_token_expiry is not None:
- # Don't allow the non-refreshable access token to outlive the
- # session.
- access_token_expiry = min(
- now_ms + self.nonrefreshable_access_token_lifetime,
- access_token_expiry,
- )
- else:
- access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime
+ valid_until_ms = self.clock.time_msec() + self.session_lifetime
refresh_token = None
refresh_token_id = None
registered_device_id = await self.device_handler.check_device_registered(
- user_id,
- device_id,
- initial_display_name,
- auth_provider_id=auth_provider_id,
- auth_provider_session_id=auth_provider_session_id,
+ user_id, device_id, initial_display_name
)
if is_guest:
- assert access_token_expiry is None
+ assert valid_until_ms is None
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
else:
if should_issue_refresh_token:
- # A refreshable access token lifetime must be configured
- # since we're told to issue a refresh token (the caller checks
- # that this value is set before setting this flag).
- assert self.refreshable_access_token_lifetime is not None
-
- # Set the expiry time of the refreshable access token
- access_token_expiry = now_ms + self.refreshable_access_token_lifetime
-
- # Set the refresh token expiry time (if configured)
- refresh_token_expiry = None
- if self.refresh_token_lifetime is not None:
- refresh_token_expiry = now_ms + self.refresh_token_lifetime
-
- # Set an ultimate session expiry time (if configured)
- ultimate_session_expiry_ts = None
- if self.session_lifetime is not None:
- ultimate_session_expiry_ts = now_ms + self.session_lifetime
-
- # Also ensure that the issued tokens don't outlive the
- # session.
- # (It would be weird to configure a homeserver with a shorter
- # session lifetime than token lifetime, but may as well handle
- # it.)
- access_token_expiry = min(
- access_token_expiry, ultimate_session_expiry_ts
- )
- if refresh_token_expiry is not None:
- refresh_token_expiry = min(
- refresh_token_expiry, ultimate_session_expiry_ts
- )
-
(
refresh_token,
refresh_token_id,
) = await self._auth_handler.create_refresh_token_for_user_id(
user_id,
device_id=registered_device_id,
- expiry_ts=refresh_token_expiry,
- ultimate_session_expiry_ts=ultimate_session_expiry_ts,
+ )
+ valid_until_ms = (
+ self.clock.time_msec() + self.refreshable_access_token_lifetime
)
access_token = await self._auth_handler.create_access_token_for_user_id(
user_id,
device_id=registered_device_id,
- valid_until_ms=access_token_expiry,
+ valid_until_ms=valid_until_ms,
is_appservice_ghost=is_appservice_ghost,
refresh_token_id=refresh_token_id,
)
@@ -890,7 +834,7 @@ class RegistrationHandler:
return {
"device_id": registered_device_id,
"access_token": access_token,
- "valid_until_ms": access_token_expiry,
+ "valid_until_ms": valid_until_ms,
"refresh_token": refresh_token,
}
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ead2198e14..88053f9869 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -46,7 +46,6 @@ from synapse.api.constants import (
from synapse.api.errors import (
AuthError,
Codes,
- HttpResponseException,
LimitExceededError,
NotFoundError,
StoreError,
@@ -57,8 +56,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
-from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers.federation import get_domains_from_state
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
@@ -1223,147 +1220,6 @@ class RoomContextHandler:
return results
-class TimestampLookupHandler:
- def __init__(self, hs: "HomeServer"):
- self.server_name = hs.hostname
- self.store = hs.get_datastore()
- self.state_handler = hs.get_state_handler()
- self.federation_client = hs.get_federation_client()
-
- async def get_event_for_timestamp(
- self,
- requester: Requester,
- room_id: str,
- timestamp: int,
- direction: str,
- ) -> Tuple[str, int]:
- """Find the closest event to the given timestamp in the given direction.
- If we can't find an event locally or the event we have locally is next to a gap,
- it will ask other federated homeservers for an event.
-
- Args:
- requester: The user making the request according to the access token
- room_id: Room to fetch the event from
- timestamp: The point in time (inclusive) we should navigate from in
- the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
- or backward from the given timestamp to find the closest event.
-
- Returns:
- A tuple containing the `event_id` closest to the given timestamp in
- the given direction and the `origin_server_ts`.
-
- Raises:
- SynapseError if unable to find any event locally in the given direction
- """
-
- local_event_id = await self.store.get_event_id_for_timestamp(
- room_id, timestamp, direction
- )
- logger.debug(
- "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s",
- local_event_id,
- timestamp,
- )
-
- # Check for gaps in the history where events could be hiding in between
- # the timestamp given and the event we were able to find locally
- is_event_next_to_backward_gap = False
- is_event_next_to_forward_gap = False
- if local_event_id:
- local_event = await self.store.get_event(
- local_event_id, allow_none=False, allow_rejected=False
- )
-
- if direction == "f":
- # We only need to check for a backward gap if we're looking forwards
- # to ensure there is nothing in between.
- is_event_next_to_backward_gap = (
- await self.store.is_event_next_to_backward_gap(local_event)
- )
- elif direction == "b":
- # We only need to check for a forward gap if we're looking backwards
- # to ensure there is nothing in between
- is_event_next_to_forward_gap = (
- await self.store.is_event_next_to_forward_gap(local_event)
- )
-
- # If we found a gap, we should probably ask another homeserver first
- # about more history in between
- if (
- not local_event_id
- or is_event_next_to_backward_gap
- or is_event_next_to_forward_gap
- ):
- logger.debug(
- "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first",
- local_event_id,
- timestamp,
- )
-
- # Find other homeservers from the given state in the room
- curr_state = await self.state_handler.get_current_state(room_id)
- curr_domains = get_domains_from_state(curr_state)
- likely_domains = [
- domain for domain, depth in curr_domains if domain != self.server_name
- ]
-
- # Loop through each homeserver candidate until we get a succesful response
- for domain in likely_domains:
- try:
- remote_response = await self.federation_client.timestamp_to_event(
- domain, room_id, timestamp, direction
- )
- logger.debug(
- "get_event_for_timestamp: response from domain(%s)=%s",
- domain,
- remote_response,
- )
-
- # TODO: Do we want to persist this as an extremity?
- # TODO: I think ideally, we would try to backfill from
- # this event and run this whole
- # `get_event_for_timestamp` function again to make sure
- # they didn't give us an event from their gappy history.
- remote_event_id = remote_response.event_id
- origin_server_ts = remote_response.origin_server_ts
-
- # Only return the remote event if it's closer than the local event
- if not local_event or (
- abs(origin_server_ts - timestamp)
- < abs(local_event.origin_server_ts - timestamp)
- ):
- return remote_event_id, origin_server_ts
- except (HttpResponseException, InvalidResponseError) as ex:
- # Let's not put a high priority on some other homeserver
- # failing to respond or giving a random response
- logger.debug(
- "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
- domain,
- type(ex).__name__,
- ex,
- ex.args,
- )
- except Exception as ex:
- # But we do want to see some exceptions in our code
- logger.warning(
- "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
- domain,
- type(ex).__name__,
- ex,
- ex.args,
- )
-
- if not local_event_id:
- raise SynapseError(
- 404,
- "Unable to find event from %s in direction %s" % (timestamp, direction),
- errcode=Codes.NOT_FOUND,
- )
-
- return local_event_id, local_event.origin_server_ts
-
-
class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -1535,13 +1391,20 @@ class RoomShutdownHandler:
await self.store.block_room(room_id, requester_user_id)
if not await self.store.get_room(room_id):
- # if we don't know about the room, there is nothing left to do.
- return {
- "kicked_users": [],
- "failed_to_kick_users": [],
- "local_aliases": [],
- "new_room_id": None,
- }
+ if block:
+ # We allow you to block an unknown room.
+ return {
+ "kicked_users": [],
+ "failed_to_kick_users": [],
+ "local_aliases": [],
+ "new_room_id": None,
+ }
+ else:
+ # But if you don't want to preventatively block another room,
+ # this function can't do anything useful.
+ raise NotFoundError(
+ "Cannot shut down room: unknown room id %s" % (room_id,)
+ )
if new_room_user_id is not None:
if not self.hs.is_mine_id(new_room_user_id):
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index b2cfe537df..8181cc0b52 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -36,9 +36,8 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
-from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
-from synapse.types import JsonDict, Requester
+from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -94,9 +93,6 @@ class RoomSummaryHandler:
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
- self._ratelimiter = Ratelimiter(
- store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
- )
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
@@ -253,7 +249,7 @@ class RoomSummaryHandler:
async def get_room_hierarchy(
self,
- requester: Requester,
+ requester: str,
requested_room_id: str,
suggested_only: bool = False,
max_depth: Optional[int] = None,
@@ -280,8 +276,6 @@ class RoomSummaryHandler:
Returns:
The JSON hierarchy dictionary.
"""
- await self._ratelimiter.ratelimit(requester)
-
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
#
@@ -289,7 +283,7 @@ class RoomSummaryHandler:
# to process multiple requests for the same page will result in errors.
return await self._pagination_response_cache.wrap(
(
- requester.user.to_string(),
+ requester,
requested_room_id,
suggested_only,
max_depth,
@@ -297,7 +291,7 @@ class RoomSummaryHandler:
from_token,
),
self._get_room_hierarchy,
- requester.user.to_string(),
+ requester,
requested_room_id,
suggested_only,
max_depth,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 65c27bc64a..49fde01cf0 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -365,7 +365,6 @@ class SsoHandler:
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
extra_login_attributes: Optional[JsonDict] = None,
- auth_provider_session_id: Optional[str] = None,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -416,8 +415,6 @@ class SsoHandler:
extra_login_attributes: An optional dictionary of extra
attributes to be provided to the client in the login response.
- auth_provider_session_id: An optional session ID from the IdP.
-
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: if the mapping provider needs to redirect the user
@@ -493,7 +490,6 @@ class SsoHandler:
client_redirect_url,
extra_login_attributes,
new_user=new_user,
- auth_provider_session_id=auth_provider_session_id,
)
async def _call_attribute_mapper(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f3039c3c3f..891435c14d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -334,19 +334,6 @@ class SyncHandler:
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> SyncResult:
- """The start of the machinery that produces a /sync response.
-
- See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
-
- This method does high-level bookkeeping:
- - tracking the kind of sync in the logging context
- - deleting any to_device messages whose delivery has been acknowledged.
- - deciding if we should dispatch an instant or delayed response
- - marking the sync as being lazily loaded, if appropriate
-
- Computing the body of the response begins in the next method,
- `current_sync_for_user`.
- """
if since_token is None:
sync_type = "initial_sync"
elif full_state:
@@ -376,7 +363,7 @@ class SyncHandler:
sync_config, since_token, full_state=full_state
)
else:
- # Otherwise, we wait for something to happen and report it to the user.
+
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
) -> SyncResult:
@@ -415,12 +402,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
- """Generates the response body of a sync result, represented as a SyncResult.
-
- This is a wrapper around `generate_sync_result` which starts an open tracing
- span to track the sync. See `generate_sync_result` for the next part of your
- indoctrination.
- """
+ """Get the sync for client needed to match what the server has now."""
with start_active_span("current_sync_for_user"):
log_kv({"since_token": since_token})
sync_result = await self.generate_sync_result(
@@ -578,7 +560,7 @@ class SyncHandler:
# that have happened since `since_key` up to `end_key`, so we
# can just use `get_room_events_stream_for_room`.
# Otherwise, we want to return the last N events in the room
- # in topological ordering.
+ # in toplogical ordering.
if since_key:
events, end_key = await self.store.get_room_events_stream_for_room(
room_id,
@@ -1060,18 +1042,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
- """Generates the response body of a sync result.
-
- This is represented by a `SyncResult` struct, which is built from small pieces
- using a `SyncResultBuilder`. See also
- https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
- the `sync_result_builder` is passed as a mutable ("inout") parameter to various
- helper functions. These retrieve and process the data which forms the sync body,
- often writing to the `sync_result_builder` to store their output.
-
- At the end, we transfer data from the `sync_result_builder` to a new `SyncResult`
- instance to signify that the sync calculation is complete.
- """
+ """Generates a sync result."""
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
@@ -1373,22 +1344,14 @@ class SyncHandler:
async def _generate_sync_entry_for_account_data(
self, sync_result_builder: "SyncResultBuilder"
) -> Dict[str, Dict[str, JsonDict]]:
- """Generates the account data portion of the sync response.
-
- Account data (called "Client Config" in the spec) can be set either globally
- or for a specific room. Account data consists of a list of events which
- accumulate state, much like a room.
-
- This function retrieves global and per-room account data. The former is written
- to the given `sync_result_builder`. The latter is returned directly, to be
- later written to the `sync_result_builder` on a room-by-room basis.
+ """Generates the account data portion of the sync response. Populates
+ `sync_result_builder` with the result.
Args:
sync_result_builder
Returns:
- A dictionary whose keys (room ids) map to the per room account data for that
- room.
+ A dictionary containing the per room account data.
"""
sync_config = sync_result_builder.sync_config
user_id = sync_result_builder.sync_config.user.to_string()
@@ -1396,7 +1359,7 @@ class SyncHandler:
if since_token and not sync_result_builder.full_state:
(
- global_account_data,
+ account_data,
account_data_by_room,
) = await self.store.get_updated_account_data_for_user(
user_id, since_token.account_data_key
@@ -1407,23 +1370,23 @@ class SyncHandler:
)
if push_rules_changed:
- global_account_data["m.push_rules"] = await self.push_rules_for_user(
+ account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
else:
(
- global_account_data,
+ account_data,
account_data_by_room,
) = await self.store.get_account_data_for_user(sync_config.user.to_string())
- global_account_data["m.push_rules"] = await self.push_rules_for_user(
+ account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
account_data_for_user = await sync_config.filter_collection.filter_account_data(
[
{"type": account_data_type, "content": content}
- for account_data_type, content in global_account_data.items()
+ for account_data_type, content in account_data.items()
]
)
@@ -1497,31 +1460,18 @@ class SyncHandler:
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
- In the response that reaches the client, rooms are divided into four categories:
- `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of
- room ids returned by this function.
-
Args:
sync_result_builder
account_data_by_room: Dictionary of per room account data
Returns:
- Returns a 4-tuple describing rooms the user has joined or left, and users who've
- joined or left rooms any rooms the user is in. This gets used later in
- `_generate_sync_entry_for_device_list`.
-
- Its entries are:
- - newly_joined_rooms
- - newly_joined_or_invited_or_knocked_users
- - newly_left_rooms
- - newly_left_users
+ Returns a 4-tuple of
+ `(newly_joined_rooms, newly_joined_or_invited_users,
+ newly_left_rooms, newly_left_users)`
"""
- since_token = sync_result_builder.since_token
-
- # 1. Start by fetching all ephemeral events in rooms we've joined (if required).
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
- since_token is None
+ sync_result_builder.since_token is None
and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
@@ -1535,8 +1485,9 @@ class SyncHandler:
)
sync_result_builder.now_token = now_token
- # 2. We check up front if anything has changed, if it hasn't then there is
+ # We check up front if anything has changed, if it hasn't then there is
# no point in going further.
+ since_token = sync_result_builder.since_token
if not sync_result_builder.full_state:
if since_token and not ephemeral_by_room and not account_data_by_room:
have_changed = await self._have_rooms_changed(sync_result_builder)
@@ -1549,8 +1500,20 @@ class SyncHandler:
logger.debug("no-oping sync")
return set(), set(), set(), set()
- # 3. Work out which rooms need reporting in the sync response.
- ignored_users = await self._get_ignored_users(user_id)
+ ignored_account_data = (
+ await self.store.get_global_account_data_by_type_for_user(
+ AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+ )
+ )
+
+ # If there is ignored users account data and it matches the proper type,
+ # then use it.
+ ignored_users: FrozenSet[str] = frozenset()
+ if ignored_account_data:
+ ignored_users_data = ignored_account_data.get("ignored_users", {})
+ if isinstance(ignored_users_data, dict):
+ ignored_users = frozenset(ignored_users_data.keys())
+
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@@ -1560,6 +1523,7 @@ class SyncHandler:
)
else:
room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
+
tags_by_room = await self.store.get_tags_for_user(user_id)
log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1570,8 +1534,6 @@ class SyncHandler:
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
- # 4. We need to apply further processing to `room_entries` (rooms considered
- # joined or archived).
async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
@@ -1590,13 +1552,31 @@ class SyncHandler:
sync_result_builder.invited.extend(invited)
sync_result_builder.knocked.extend(knocked)
- # 5. Work out which users have joined or left rooms we're in. We use this
- # to build the device_list part of the sync response in
- # `_generate_sync_entry_for_device_list`.
- (
- newly_joined_or_invited_or_knocked_users,
- newly_left_users,
- ) = sync_result_builder.calculate_user_changes()
+ # Now we want to get any newly joined, invited or knocking users
+ newly_joined_or_invited_or_knocked_users = set()
+ newly_left_users = set()
+ if since_token:
+ for joined_sync in sync_result_builder.joined:
+ it = itertools.chain(
+ joined_sync.timeline.events, joined_sync.state.values()
+ )
+ for event in it:
+ if event.type == EventTypes.Member:
+ if (
+ event.membership == Membership.JOIN
+ or event.membership == Membership.INVITE
+ or event.membership == Membership.KNOCK
+ ):
+ newly_joined_or_invited_or_knocked_users.add(
+ event.state_key
+ )
+ else:
+ prev_content = event.unsigned.get("prev_content", {})
+ prev_membership = prev_content.get("membership", None)
+ if prev_membership == Membership.JOIN:
+ newly_left_users.add(event.state_key)
+
+ newly_left_users -= newly_joined_or_invited_or_knocked_users
return (
set(newly_joined_rooms),
@@ -1605,36 +1585,11 @@ class SyncHandler:
newly_left_users,
)
- async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
- """Retrieve the users ignored by the given user from their global account_data.
-
- Returns an empty set if
- - there is no global account_data entry for ignored_users
- - there is such an entry, but it's not a JSON object.
- """
- # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
- ignored_account_data = (
- await self.store.get_global_account_data_by_type_for_user(
- AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
- )
- )
-
- # If there is ignored users account data and it matches the proper type,
- # then use it.
- ignored_users: FrozenSet[str] = frozenset()
- if ignored_account_data:
- ignored_users_data = ignored_account_data.get("ignored_users", {})
- if isinstance(ignored_users_data, dict):
- ignored_users = frozenset(ignored_users_data.keys())
- return ignored_users
-
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
"""Returns whether there may be any new events that should be sent down
the sync. Returns True if there are.
-
- Does not modify the `sync_result_builder`.
"""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -1642,13 +1597,12 @@ class SyncHandler:
assert since_token
- # Get a list of membership change events that have happened to the user
- # requesting the sync.
- membership_changes = await self.store.get_membership_changes_for_user(
+ # Get a list of membership change events that have happened.
+ rooms_changed = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
- if membership_changes:
+ if rooms_changed:
return True
stream_id = since_token.room_key.stream
@@ -1660,25 +1614,7 @@ class SyncHandler:
async def _get_rooms_changed(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
- """Determine the changes in rooms to report to the user.
-
- Ideally, we want to report all events whose stream ordering `s` lies in the
- range `since_token < s <= now_token`, where the two tokens are read from the
- sync_result_builder.
-
- If there are too many events in that range to report, things get complicated.
- In this situation we return a truncated list of the most recent events, and
- indicate in the response that there is a "gap" of omitted events. Additionally:
-
- - we include a "state_delta", to describe the changes in state over the gap,
- - we include all membership events applying to the user making the request,
- even those in the gap.
-
- See the spec for the rationale:
- https://spec.matrix.org/v1.1/client-server-api/#syncing
-
- The sync_result_builder is not modified by this function.
- """
+ """Gets the the changes that have happened since the last sync."""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
@@ -1686,36 +1622,21 @@ class SyncHandler:
assert since_token
- # The spec
- # https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
- # notes that membership events need special consideration:
- #
- # > When a sync is limited, the server MUST return membership events for events
- # > in the gap (between since and the start of the returned timeline), regardless
- # > as to whether or not they are redundant.
- #
- # We fetch such events here, but we only seem to use them for categorising rooms
- # as newly joined, newly left, invited or knocked.
- # TODO: we've already called this function and ran this query in
- # _have_rooms_changed. We could keep the results in memory to avoid a
- # second query, at the cost of more complicated source code.
- membership_change_events = await self.store.get_membership_changes_for_user(
+ # Get a list of membership change events that have happened.
+ rooms_changed = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
- for event in membership_change_events:
+ for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
- newly_joined_rooms: List[str] = []
- newly_left_rooms: List[str] = []
- room_entries: List[RoomSyncResultBuilder] = []
- invited: List[InvitedSyncResult] = []
- knocked: List[KnockedSyncResult] = []
+ newly_joined_rooms = []
+ newly_left_rooms = []
+ room_entries = []
+ invited = []
+ knocked = []
for room_id, events in mem_change_events_by_room_id.items():
- # The body of this loop will add this room to at least one of the five lists
- # above. Things get messy if you've e.g. joined, left, joined then left the
- # room all in the same sync period.
logger.debug(
"Membership changes in %s: [%s]",
room_id,
@@ -1770,7 +1691,6 @@ class SyncHandler:
if not non_joins:
continue
- last_non_join = non_joins[-1]
# Check if we have left the room. This can either be because we were
# joined before *or* that we since joined and then left.
@@ -1792,18 +1712,18 @@ class SyncHandler:
newly_left_rooms.append(room_id)
# Only bother if we're still currently invited
- should_invite = last_non_join.membership == Membership.INVITE
+ should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite:
- if last_non_join.sender not in ignored_users:
- invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join)
+ if event.sender not in ignored_users:
+ invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if invite_room_sync:
invited.append(invite_room_sync)
# Only bother if our latest membership in the room is knock (and we haven't
# been accepted/rejected in the meantime).
- should_knock = last_non_join.membership == Membership.KNOCK
+ should_knock = non_joins[-1].membership == Membership.KNOCK
if should_knock:
- knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join)
+ knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
if knock_room_sync:
knocked.append(knock_room_sync)
@@ -1861,9 +1781,7 @@ class SyncHandler:
timeline_limit = sync_config.filter_collection.timeline_limit()
- # Get all events since the `from_key` in rooms we're currently joined to.
- # If there are too many, we get the most recent events only. This leaves
- # a "gap" in the timeline, as described by the spec for /sync.
+ # Get all events for rooms we're currently joined to.
room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key,
@@ -1924,10 +1842,6 @@ class SyncHandler:
) -> _RoomChanges:
"""Returns entries for all rooms for the user.
- Like `_get_rooms_changed`, but assumes the `since_token` is `None`.
-
- This function does not modify the sync_result_builder.
-
Args:
sync_result_builder
ignored_users: Set of users ignored by user.
@@ -1939,9 +1853,16 @@ class SyncHandler:
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
+ membership_list = (
+ Membership.INVITE,
+ Membership.KNOCK,
+ Membership.JOIN,
+ Membership.LEAVE,
+ Membership.BAN,
+ )
+
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
- user_id=user_id,
- membership_list=Membership.LIST,
+ user_id=user_id, membership_list=membership_list
)
room_entries = []
@@ -2291,7 +2212,8 @@ def _calculate_state(
# to only include membership events for the senders in the timeline.
# In practice, we can do this by removing them from the p_ids list,
# which is the list of relevant state we know we have already sent to the client.
- # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
+ # see https://github.com/matrix-org/synapse/pull/2970
+ # /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
if lazy_load_members:
p_ids.difference_update(
@@ -2340,39 +2262,6 @@ class SyncResultBuilder:
groups: Optional[GroupsSyncResult] = None
to_device: List[JsonDict] = attr.Factory(list)
- def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
- """Work out which other users have joined or left rooms we are joined to.
-
- This data only is only useful for an incremental sync.
-
- The SyncResultBuilder is not modified by this function.
- """
- newly_joined_or_invited_or_knocked_users = set()
- newly_left_users = set()
- if self.since_token:
- for joined_sync in self.joined:
- it = itertools.chain(
- joined_sync.timeline.events, joined_sync.state.values()
- )
- for event in it:
- if event.type == EventTypes.Member:
- if (
- event.membership == Membership.JOIN
- or event.membership == Membership.INVITE
- or event.membership == Membership.KNOCK
- ):
- newly_joined_or_invited_or_knocked_users.add(
- event.state_key
- )
- else:
- prev_content = event.unsigned.get("prev_content", {})
- prev_membership = prev_content.get("membership", None)
- if prev_membership == Membership.JOIN:
- newly_left_users.add(event.state_key)
-
- newly_left_users -= newly_joined_or_invited_or_knocked_users
- return newly_joined_or_invited_or_knocked_users, newly_left_users
-
@attr.s(slots=True, auto_attribs=True)
class RoomSyncResultBuilder:
|