summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2023-01-10 12:43:28 +0000
committerDavid Robertson <davidr@element.io>2023-01-10 12:43:28 +0000
commit04aa6a970790543e773639c821d0628e5fadaf32 (patch)
tree48adfb5013d2a31253030a0f2508896e7c72b390 /synapse
parentMerge branch 'rei/dresync_exp' into matrix-org-hotfixes (diff)
parentUpdate changelog 2 (diff)
downloadsynapse-04aa6a970790543e773639c821d0628e5fadaf32.tar.xz
Merge remote-tracking branch 'origin/release-v1.75' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py2
-rw-r--r--synapse/api/filtering.py13
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/config/oidc.py6
-rw-r--r--synapse/handlers/account_data.py111
-rw-r--r--synapse/handlers/device.py9
-rw-r--r--synapse/handlers/oidc.py85
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/sync.py34
-rw-r--r--synapse/module_api/__init__.py40
-rw-r--r--synapse/push/clientformat.py5
-rw-r--r--synapse/replication/http/account_data.py92
-rw-r--r--synapse/replication/tcp/client.py3
-rw-r--r--synapse/rest/client/account.py5
-rw-r--r--synapse/rest/client/account_data.py115
-rw-r--r--synapse/rest/media/v1/oembed.py15
-rw-r--r--synapse/storage/_base.py17
-rw-r--r--synapse/storage/database.py33
-rw-r--r--synapse/storage/databases/main/account_data.py233
-rw-r--r--synapse/storage/databases/main/cache.py11
-rw-r--r--synapse/storage/databases/main/deviceinbox.py7
-rw-r--r--synapse/storage/databases/main/devices.py11
-rw-r--r--synapse/storage/databases/main/events_worker.py15
-rw-r--r--synapse/storage/databases/main/presence.py8
-rw-r--r--synapse/storage/databases/main/push_rule.py7
-rw-r--r--synapse/storage/databases/main/pusher.py6
-rw-r--r--synapse/storage/databases/main/receipts.py7
-rw-r--r--synapse/storage/databases/main/tags.py8
-rw-r--r--synapse/util/macaroons.py7
30 files changed, 819 insertions, 93 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py

index d850e54e17..c463b60b26 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py
@@ -1307,7 +1307,7 @@ def main() -> None: sqlite_config = { "name": "sqlite3", "args": { - "database": args.sqlite_database, + "database": "file:{}?mode=rw".format(args.sqlite_database), "cp_min": 1, "cp_max": 1, "check_same_thread": False, diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index a9888381b4..2b5af264b4 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py
@@ -283,6 +283,9 @@ class FilterCollection: await self._room_filter.filter(events) ) + def blocks_all_rooms(self) -> bool: + return self._room_filter.filters_all_rooms() + def blocks_all_presence(self) -> bool: return ( self._presence_filter.filters_all_types() @@ -351,13 +354,13 @@ class Filter: self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: - return "*" in self.not_types + return self.types == [] or "*" in self.not_types def filters_all_senders(self) -> bool: - return "*" in self.not_senders + return self.senders == [] or "*" in self.not_senders def filters_all_rooms(self) -> bool: - return "*" in self.not_rooms + return self.rooms == [] or "*" in self.not_rooms def _check(self, event: FilterEvent) -> bool: """Checks whether the filter matches the given event. @@ -450,8 +453,8 @@ class Filter: if any(map(match_func, disallowed_values)): return False - # Other the event does not match at least one of the allowed values, - # reject it. + # Otherwise if the event does not match at least one of the allowed + # values, reject it. allowed_values = getattr(self, name) if allowed_values is not None: if not any(map(match_func, allowed_values)): diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 01ea2b4dab..bd265de536 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi
@@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse from typing import ( Any, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 573fa0386f..0f3870bfe1 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py
@@ -136,3 +136,6 @@ class ExperimentalConfig(Config): # Enable room version (and thus applicable push rules from MSC3931/3932) version_id = RoomVersions.MSC1767v10.identifier KNOWN_ROOM_VERSIONS[version_id] = RoomVersions.MSC1767v10 + + # MSC3391: Removing account data. + self.msc3391_enabled = experimental.get("msc3391_enabled", False) diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 0bd83f4010..df8c422043 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py
@@ -117,6 +117,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { # to avoid importing authlib here. "enum": ["client_secret_basic", "client_secret_post", "none"], }, + "pkce_method": {"type": "string", "enum": ["auto", "always", "never"]}, "scopes": {"type": "array", "items": {"type": "string"}}, "authorization_endpoint": {"type": "string"}, "token_endpoint": {"type": "string"}, @@ -289,6 +290,7 @@ def _parse_oidc_config_dict( client_secret=oidc_config.get("client_secret"), client_secret_jwt_key=client_secret_jwt_key, client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), + pkce_method=oidc_config.get("pkce_method", "auto"), scopes=oidc_config.get("scopes", ["openid"]), authorization_endpoint=oidc_config.get("authorization_endpoint"), token_endpoint=oidc_config.get("token_endpoint"), @@ -357,6 +359,10 @@ class OidcProviderConfig: # 'none'. client_auth_method: str + # Whether to enable PKCE when exchanging the authorization & token. + # Valid values are 'auto', 'always', and 'never'. + pkce_method: str + # list of scopes to request scopes: Collection[str] diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index fc21d58001..aba7315cf7 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py
@@ -17,10 +17,12 @@ import random from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple from synapse.replication.http.account_data import ( + ReplicationAddRoomAccountDataRestServlet, ReplicationAddTagRestServlet, + ReplicationAddUserAccountDataRestServlet, + ReplicationRemoveRoomAccountDataRestServlet, ReplicationRemoveTagRestServlet, - ReplicationRoomAccountDataRestServlet, - ReplicationUserAccountDataRestServlet, + ReplicationRemoveUserAccountDataRestServlet, ) from synapse.streams import EventSource from synapse.types import JsonDict, StreamKeyType, UserID @@ -41,8 +43,18 @@ class AccountDataHandler: self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() - self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) - self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) + self._add_user_data_client = ( + ReplicationAddUserAccountDataRestServlet.make_client(hs) + ) + self._remove_user_data_client = ( + ReplicationRemoveUserAccountDataRestServlet.make_client(hs) + ) + self._add_room_data_client = ( + ReplicationAddRoomAccountDataRestServlet.make_client(hs) + ) + self._remove_room_data_client = ( + ReplicationRemoveRoomAccountDataRestServlet.make_client(hs) + ) self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) self._account_data_writers = hs.config.worker.writers.account_data @@ -112,7 +124,7 @@ class AccountDataHandler: return max_stream_id else: - response = await self._room_data_client( + response = await self._add_room_data_client( instance_name=random.choice(self._account_data_writers), user_id=user_id, room_id=room_id, @@ -121,15 +133,59 @@ class AccountDataHandler: ) return response["max_stream_id"] + async def remove_account_data_for_room( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[int]: + """ + Deletes the room account data for the given user and account data type. + + "Deleting" account data merely means setting the content of the account data + to an empty JSON object: {}. + + Args: + user_id: The user ID to remove room account data for. + room_id: The room ID to target. + account_data_type: The account data type to remove. + + Returns: + The maximum stream ID, or None if the room account data item did not exist. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + if max_stream_id is None: + # The referenced account data did not exist, so no delete occurred. + return None + + self._notifier.on_new_event( + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] + ) + + # Notify Synapse modules that the content of the type has changed to an + # empty dictionary. + await self._notify_modules(user_id, room_id, account_data_type, {}) + + return max_stream_id + else: + response = await self._remove_room_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + account_data_type=account_data_type, + content={}, + ) + return response["max_stream_id"] + async def add_account_data_for_user( self, user_id: str, account_data_type: str, content: JsonDict ) -> int: """Add some global account_data for a user. Args: - user_id: The user to add a tag for. + user_id: The user to add some account data for. account_data_type: The type of account_data to add. - content: A json object to associate with the tag. + content: The content json dictionary. Returns: The maximum stream ID. @@ -148,7 +204,7 @@ class AccountDataHandler: return max_stream_id else: - response = await self._user_data_client( + response = await self._add_user_data_client( instance_name=random.choice(self._account_data_writers), user_id=user_id, account_data_type=account_data_type, @@ -156,6 +212,45 @@ class AccountDataHandler: ) return response["max_stream_id"] + async def remove_account_data_for_user( + self, user_id: str, account_data_type: str + ) -> Optional[int]: + """Removes a piece of global account_data for a user. + + Args: + user_id: The user to remove account data for. + account_data_type: The type of account_data to remove. + + Returns: + The maximum stream ID, or None if the room account data item did not exist. + """ + + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_account_data_for_user( + user_id, account_data_type + ) + if max_stream_id is None: + # The referenced account data did not exist, so no delete occurred. + return None + + self._notifier.on_new_event( + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] + ) + + # Notify Synapse modules that the content of the type has changed to an + # empty dictionary. + await self._notify_modules(user_id, None, account_data_type, {}) + + return max_stream_id + else: + response = await self._remove_user_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + account_data_type=account_data_type, + content={}, + ) + return response["max_stream_id"] + async def add_tag_to_room( self, user_id: str, room_id: str, tag: str, content: JsonDict ) -> int: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 68a0c8ccb4..89864e1119 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -919,6 +919,11 @@ class DeviceListWorkerUpdater: """ # mark_failed_as_stale is not sent. Ensure this doesn't break expectations. assert mark_failed_as_stale + + if not user_ids: + # Shortcut empty requests + return {} + try: return await self._multi_user_device_resync_client(user_ids=user_ids) except SynapseError as err: @@ -946,6 +951,8 @@ class DeviceListWorkerUpdater: A dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + None when we weren't able to fetch the device info for some reason, + e.g. due to a connection problem. """ return (await self.multi_user_device_resync([user_id]))[user_id] @@ -1250,6 +1257,8 @@ class DeviceListUpdater(DeviceListWorkerUpdater): - A dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + None when we weren't able to fetch the device info for some reason, + e.g. due to a connection problem. - True iff the resync failed and the device list should be marked as stale. """ logger.debug("Attempting to resync the device list for %s", user_id) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 03de6a4ba6..0fc829acf7 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py
@@ -36,6 +36,7 @@ from authlib.jose import JsonWebToken, JWTClaims from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri +from authlib.oauth2.rfc7636.challenge import create_s256_code_challenge from authlib.oidc.core import CodeIDToken, UserInfo from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template @@ -475,6 +476,16 @@ class OidcProvider: ) ) + # If PKCE support is advertised ensure the wanted method is available. + if m.get("code_challenge_methods_supported") is not None: + m.validate_code_challenge_methods_supported() + if "S256" not in m["code_challenge_methods_supported"]: + raise ValueError( + '"S256" not in "code_challenge_methods_supported" ({supported!r})'.format( + supported=m["code_challenge_methods_supported"], + ) + ) + if m.get("response_types_supported") is not None: m.validate_response_types_supported() @@ -602,6 +613,11 @@ class OidcProvider: if self._config.jwks_uri: metadata["jwks_uri"] = self._config.jwks_uri + if self._config.pkce_method == "always": + metadata["code_challenge_methods_supported"] = ["S256"] + elif self._config.pkce_method == "never": + metadata.pop("code_challenge_methods_supported", None) + self._validate_metadata(metadata) return metadata @@ -653,7 +669,7 @@ class OidcProvider: return jwk_set - async def _exchange_code(self, code: str) -> Token: + async def _exchange_code(self, code: str, code_verifier: str) -> Token: """Exchange an authorization code for a token. This calls the ``token_endpoint`` with the authorization code we @@ -666,6 +682,7 @@ class OidcProvider: Args: code: The authorization code we got from the callback. + code_verifier: The PKCE code verifier to send, blank if unused. Returns: A dict containing various tokens. @@ -696,6 +713,8 @@ class OidcProvider: "code": code, "redirect_uri": self._callback_url, } + if code_verifier: + args["code_verifier"] = code_verifier body = urlencode(args, True) # Fill the body/headers with credentials @@ -914,11 +933,14 @@ class OidcProvider: - ``scope``: the list of scopes set in ``oidc_config.scopes`` - ``state``: a random string - ``nonce``: a random string + - ``code_challenge``: a RFC7636 code challenge (if PKCE is supported) - In addition generating a redirect URL, we are setting a cookie with - a signed macaroon token containing the state, the nonce and the - client_redirect_url params. Those are then checked when the client - comes back from the provider. + In addition to generating a redirect URL, we are setting a cookie with + a signed macaroon token containing the state, the nonce, the + client_redirect_url, and (optionally) the code_verifier params. The state, + nonce, and client_redirect_url are then checked when the client comes back + from the provider. The code_verifier is passed back to the server during + the token exchange and compared to the code_challenge sent in this request. Args: request: the incoming request from the browser. @@ -935,10 +957,25 @@ class OidcProvider: state = generate_token() nonce = generate_token() + code_verifier = "" if not client_redirect_url: client_redirect_url = b"" + metadata = await self.load_metadata() + + # Automatically enable PKCE if it is supported. + extra_grant_values = {} + if metadata.get("code_challenge_methods_supported"): + code_verifier = generate_token(48) + + # Note that we verified the server supports S256 earlier (in + # OidcProvider._validate_metadata). + extra_grant_values = { + "code_challenge_method": "S256", + "code_challenge": create_s256_code_challenge(code_verifier), + } + cookie = self._macaroon_generaton.generate_oidc_session_token( state=state, session_data=OidcSessionData( @@ -946,6 +983,7 @@ class OidcProvider: nonce=nonce, client_redirect_url=client_redirect_url.decode(), ui_auth_session_id=ui_auth_session_id or "", + code_verifier=code_verifier, ), ) @@ -966,7 +1004,6 @@ class OidcProvider: ) ) - metadata = await self.load_metadata() authorization_endpoint = metadata.get("authorization_endpoint") return prepare_grant_uri( authorization_endpoint, @@ -976,6 +1013,7 @@ class OidcProvider: scope=self._scopes, state=state, nonce=nonce, + **extra_grant_values, ) async def handle_oidc_callback( @@ -1003,7 +1041,9 @@ class OidcProvider: # Exchange the code with the provider try: logger.debug("Exchanging OAuth2 code for a token") - token = await self._exchange_code(code) + token = await self._exchange_code( + code, code_verifier=session_data.code_verifier + ) except OidcError as e: logger.warning("Could not exchange OAuth2 code: %s", e) self._sso_handler.render_error(request, e.error, e.error_description) @@ -1520,8 +1560,8 @@ env.filters.update( @attr.s(slots=True, frozen=True, auto_attribs=True) class JinjaOidcMappingConfig: - subject_claim: str - picture_claim: str + subject_template: Template + picture_template: Template localpart_template: Optional[Template] display_name_template: Optional[Template] email_template: Optional[Template] @@ -1540,8 +1580,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @staticmethod def parse_config(config: dict) -> JinjaOidcMappingConfig: - subject_claim = config.get("subject_claim", "sub") - picture_claim = config.get("picture_claim", "picture") + def parse_template_config_with_claim( + option_name: str, default_claim: str + ) -> Template: + template_name = f"{option_name}_template" + template = config.get(template_name) + if not template: + # Convert the legacy subject_claim into a template. + claim = config.get(f"{option_name}_claim", default_claim) + template = "{{ user.%s }}" % (claim,) + + try: + return env.from_string(template) + except Exception as e: + raise ConfigError("invalid jinja template", path=[template_name]) from e + + subject_template = parse_template_config_with_claim("subject", "sub") + picture_template = parse_template_config_with_claim("picture", "picture") def parse_template_config(option_name: str) -> Optional[Template]: if option_name not in config: @@ -1574,8 +1629,8 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): raise ConfigError("must be a bool", path=["confirm_localpart"]) return JinjaOidcMappingConfig( - subject_claim=subject_claim, - picture_claim=picture_claim, + subject_template=subject_template, + picture_template=picture_template, localpart_template=localpart_template, display_name_template=display_name_template, email_template=email_template, @@ -1584,7 +1639,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): ) def get_remote_user_id(self, userinfo: UserInfo) -> str: - return userinfo[self._config.subject_claim] + return self._config.subject_template.render(user=userinfo).strip() async def map_user_attributes( self, userinfo: UserInfo, token: Token, failures: int @@ -1615,7 +1670,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if email: emails.append(email) - picture = userinfo.get("picture") + picture = self._config.picture_template.render(user=userinfo).strip() return UserAttributeDict( localpart=localpart, diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 33115ce488..40f4635c4e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py
@@ -275,7 +275,7 @@ class SearchHandler: ) room_ids = {r.room_id for r in rooms} - # If doing a subset of all rooms seearch, check if any of the rooms + # If doing a subset of all rooms search, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: historical_room_ids: List[str] = [] diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7d6a653747..6942e06c77 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.handlers.relations import BundledAggregations +from synapse.logging import issue9533_logger from synapse.logging.context import current_context from synapse.logging.opentracing import ( SynapseTags, @@ -1402,11 +1403,14 @@ class SyncHandler: logger.debug("Fetching room data") - res = await self._generate_sync_entry_for_rooms( + ( + newly_joined_rooms, + newly_joined_or_invited_or_knocked_users, + newly_left_rooms, + newly_left_users, + ) = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) - newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res - _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( since_token is None and sync_config.filter_collection.blocks_all_presence() @@ -1623,13 +1627,18 @@ class SyncHandler: } ) - logger.debug( - "Returning %d to-device messages between %d and %d (current token: %d)", - len(messages), - since_stream_id, - stream_id, - now_token.to_device_key, - ) + if messages and issue9533_logger.isEnabledFor(logging.DEBUG): + issue9533_logger.debug( + "Returning to-device messages with stream_ids (%d, %d]; now: %d;" + " msgids: %s", + since_stream_id, + stream_id, + now_token.to_device_key, + [ + message["content"].get(EventContentFields.TO_DEVICE_MSGID) + for message in messages + ], + ) sync_result_builder.now_token = now_token.copy_and_replace( StreamKeyType.TO_DEVICE, stream_id ) @@ -1783,6 +1792,11 @@ class SyncHandler: - newly_left_rooms - newly_left_users """ + + # If the request doesn't care about rooms then nothing to do! + if sync_result_builder.sync_config.filter_collection.blocks_all_rooms(): + return set(), set(), set(), set() + since_token = sync_result_builder.since_token # 1. Start by fetching all ephemeral events in rooms we've joined (if required). diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 0092a03c59..6f4a934b05 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -18,6 +18,7 @@ from typing import ( TYPE_CHECKING, Any, Callable, + Collection, Dict, Generator, Iterable, @@ -126,7 +127,7 @@ from synapse.types import ( from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable -from synapse.util.caches.descriptors import CachedFunction, cached +from synapse.util.caches.descriptors import CachedFunction, cached as _cached from synapse.util.frozenutils import freeze if TYPE_CHECKING: @@ -136,6 +137,7 @@ if TYPE_CHECKING: T = TypeVar("T") P = ParamSpec("P") +F = TypeVar("F", bound=Callable[..., Any]) """ This package defines the 'stable' API which can be used by extension modules which @@ -185,6 +187,42 @@ class UserIpAndAgent: last_seen: int +def cached( + *, + max_entries: int = 1000, + num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, +) -> Callable[[F], CachedFunction[F]]: + """Returns a decorator that applies a memoizing cache around the function. This + decorator behaves similarly to functools.lru_cache. + + Example: + + @cached() + def foo('a', 'b'): + ... + + Added in Synapse v1.74.0. + + Args: + max_entries: The maximum number of entries in the cache. If the cache is full + and a new entry is added, the least recently accessed entry will be evicted + from the cache. + num_args: The number of positional arguments (excluding `self`) to use as cache + keys. Defaults to all named args of the function. + uncached_args: A list of argument names to not use as the cache key. (`self` is + always ignored.) Cannot be used with num_args. + + Returns: + A decorator that applies a memoizing cache around the function. + """ + return _cached( + max_entries=max_entries, + num_args=num_args, + uncached_args=uncached_args, + ) + + class ModuleApi: """A proxy object that gets passed to various plugin modules so they can register new users etc if necessary. diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 622a1e35c5..bb76c169c6 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py
@@ -26,10 +26,7 @@ def format_push_rules_for_user( """Converts a list of rawrules and a enabled map into nested dictionaries to match the Matrix client-server format for push rules""" - rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = { - "global": {}, - "device": {}, - } + rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {"global": {}} rules["global"] = _add_empty_priority_class_arrays(rules["global"]) diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
index 310f609153..0edc95977b 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py
@@ -28,7 +28,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): +class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint): """Add user account data on the appropriate account data worker. Request format: @@ -49,7 +49,6 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload( # type: ignore[override] @@ -73,7 +72,45 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} -class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): +class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint): + """Remove user account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/remove_user_account_data/:user_id/:type + + { + "content": { ... }, + } + + """ + + NAME = "remove_user_account_data" + PATH_ARGS = ("user_id", "account_data_type") + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, account_data_type: str + ) -> JsonDict: + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: + max_stream_id = await self.handler.remove_account_data_for_user( + user_id, account_data_type + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint): """Add room account data on the appropriate account data worker. Request format: @@ -94,7 +131,6 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload( # type: ignore[override] @@ -118,6 +154,44 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} +class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint): + """Remove room account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/remove_room_account_data/:user_id/:room_id/:account_data_type + + { + "content": { ... }, + } + + """ + + NAME = "remove_room_account_data" + PATH_ARGS = ("user_id", "room_id", "account_data_type") + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> JsonDict: + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, room_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: + max_stream_id = await self.handler.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + + return 200, {"max_stream_id": max_stream_id} + + class ReplicationAddTagRestServlet(ReplicationEndpoint): """Add tag on the appropriate account data worker. @@ -139,7 +213,6 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload( # type: ignore[override] @@ -186,7 +259,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): super().__init__(hs) self.handler = hs.get_account_data_handler() - self.clock = hs.get_clock() @staticmethod async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] @@ -206,7 +278,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - ReplicationUserAccountDataRestServlet(hs).register(http_server) - ReplicationRoomAccountDataRestServlet(hs).register(http_server) + ReplicationAddUserAccountDataRestServlet(hs).register(http_server) + ReplicationAddRoomAccountDataRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server) ReplicationRemoveTagRestServlet(hs).register(http_server) + + if hs.config.experimental.msc3391_enabled: + ReplicationRemoveUserAccountDataRestServlet(hs).register(http_server) + ReplicationRemoveRoomAccountDataRestServlet(hs).register(http_server) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 658d89210d..b5e40da533 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -152,6 +152,9 @@ class ReplicationDataHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + # NOTE: this must be called after process_replication_rows to ensure any + # cache invalidations are first handled before any stream ID advances. + self.store.process_replication_position(stream_name, instance_name, token) if self.send_handler: await self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index c1781bc814..232f3a976d 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py
@@ -338,6 +338,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + if not self.hs.config.registration.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + if not self.config.email.can_verify_email: logger.warning( "Adding emails have been disabled due to lack of an email config" diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index f13970b898..e805196fec 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py
@@ -41,6 +41,7 @@ class AccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() @@ -54,6 +55,16 @@ class AccountDataServlet(RestServlet): body = parse_json_object_from_request(request) + # If experimental support for MSC3391 is enabled, then providing an empty dict + # as the value for an account data type should be functionally equivalent to + # calling the DELETE method on the same type. + if self._hs.config.experimental.msc3391_enabled: + if body == {}: + await self.handler.remove_account_data_for_user( + user_id, account_data_type + ) + return 200, {} + await self.handler.add_account_data_for_user(user_id, account_data_type, body) return 200, {} @@ -72,9 +83,48 @@ class AccountDataServlet(RestServlet): if event is None: raise NotFoundError("Account data not found") + # If experimental support for MSC3391 is enabled, then this endpoint should + # return a 404 if the content for an account data type is an empty dict. + if self._hs.config.experimental.msc3391_enabled and event == {}: + raise NotFoundError("Account data not found") + return 200, event +class UnstableAccountDataServlet(RestServlet): + """ + Contains an unstable endpoint for removing user account data, as specified by + MSC3391. If that MSC is accepted, this code should have unstable prefixes removed + and become incorporated into AccountDataServlet above. + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3391/user/(?P<user_id>[^/]*)" + "/account_data/(?P<account_data_type>[^/]*)", + unstable=True, + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.handler = hs.get_account_data_handler() + + async def on_DELETE( + self, + request: SynapseRequest, + user_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot delete account data for other users.") + + await self.handler.remove_account_data_for_user(user_id, account_data_type) + + return 200, {} + + class RoomAccountDataServlet(RestServlet): """ PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 @@ -89,6 +139,7 @@ class RoomAccountDataServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() + self._hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main self.handler = hs.get_account_data_handler() @@ -121,6 +172,16 @@ class RoomAccountDataServlet(RestServlet): Codes.BAD_JSON, ) + # If experimental support for MSC3391 is enabled, then providing an empty dict + # as the value for an account data type should be functionally equivalent to + # calling the DELETE method on the same type. + if self._hs.config.experimental.msc3391_enabled: + if body == {}: + await self.handler.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + return 200, {} + await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, body ) @@ -152,9 +213,63 @@ class RoomAccountDataServlet(RestServlet): if event is None: raise NotFoundError("Room account data not found") + # If experimental support for MSC3391 is enabled, then this endpoint should + # return a 404 if the content for an account data type is an empty dict. + if self._hs.config.experimental.msc3391_enabled and event == {}: + raise NotFoundError("Room account data not found") + return 200, event +class UnstableRoomAccountDataServlet(RestServlet): + """ + Contains an unstable endpoint for removing room account data, as specified by + MSC3391. If that MSC is accepted, this code should have unstable prefixes removed + and become incorporated into RoomAccountDataServlet above. + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3391/user/(?P<user_id>[^/]*)" + "/rooms/(?P<room_id>[^/]*)" + "/account_data/(?P<account_data_type>[^/]*)", + unstable=True, + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.handler = hs.get_account_data_handler() + + async def on_DELETE( + self, + request: SynapseRequest, + user_id: str, + room_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot delete account data for other users.") + + if not RoomID.is_valid(room_id): + raise SynapseError( + 400, + f"{room_id} is not a valid room ID", + Codes.INVALID_PARAM, + ) + + await self.handler.remove_account_data_for_room( + user_id, room_id, account_data_type + ) + + return 200, {} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AccountDataServlet(hs).register(http_server) RoomAccountDataServlet(hs).register(http_server) + + if hs.config.experimental.msc3391_enabled: + UnstableAccountDataServlet(hs).register(http_server) + UnstableRoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 827afd868d..a3738a6250 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py
@@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import html import logging import urllib.parse from typing import TYPE_CHECKING, List, Optional @@ -161,7 +162,9 @@ class OEmbedProvider: title = oembed.get("title") if title and isinstance(title, str): - open_graph_response["og:title"] = title + # A common WordPress plug-in seems to incorrectly escape entities + # in the oEmbed response. + open_graph_response["og:title"] = html.unescape(title) author_name = oembed.get("author_name") if not isinstance(author_name, str): @@ -180,9 +183,9 @@ class OEmbedProvider: # Process each type separately. oembed_type = oembed.get("type") if oembed_type == "rich": - html = oembed.get("html") - if isinstance(html, str): - calc_description_and_urls(open_graph_response, html) + html_str = oembed.get("html") + if isinstance(html_str, str): + calc_description_and_urls(open_graph_response, html_str) elif oembed_type == "photo": # If this is a photo, use the full image, not the thumbnail. @@ -192,8 +195,8 @@ class OEmbedProvider: elif oembed_type == "video": open_graph_response["og:type"] = "video.other" - html = oembed.get("html") - if html and isinstance(html, str): + html_str = oembed.get("html") + if html_str and isinstance(html_str, str): calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 69abf6fa87..41d9111019 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta): token: int, rows: Iterable[Any], ) -> None: - pass + """ + Used by storage classes to invalidate caches based on incoming replication data. These + must not update any ID generators, use `process_replication_position`. + """ + + def process_replication_position( # noqa: B027 (no-op by design) + self, + stream_name: str, + instance_name: str, + token: int, + ) -> None: + """ + Used by storage classes to advance ID generators based on incoming replication data. This + is called after process_replication_rows such that caches are invalidated before any token + positions advance. + """ def _invalidate_state_caches( self, room_id: str, members_changed: Collection[str] diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0b29e67b94..88479a16db 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -1762,7 +1762,8 @@ class DatabasePool: desc: description of the transaction, for logging and metrics Returns: - A list of dictionaries. + A list of dictionaries, one per result row, each a mapping between the + column names from `retcols` and that column's value for the row. """ return await self.runInteraction( desc, @@ -1791,6 +1792,10 @@ class DatabasePool: column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + + Returns: + A list of dictionaries, one per result row, each a mapping between the + column names from `retcols` and that column's value for the row. """ if keyvalues: sql = "SELECT %s FROM %s WHERE %s" % ( @@ -1898,6 +1903,19 @@ class DatabasePool: updatevalues: Dict[str, Any], desc: str, ) -> int: + """ + Update rows in the given database table. + If the given keyvalues don't match anything, nothing will be updated. + + Args: + table: The database table to update. + keyvalues: A mapping of column name to value to match rows on. + updatevalues: A mapping of column name to value to replace in any matched rows. + desc: description of the transaction, for logging and metrics. + + Returns: + The number of rows that were updated. Will be 0 if no matching rows were found. + """ return await self.runInteraction( desc, self.simple_update_txn, table, keyvalues, updatevalues ) @@ -1909,6 +1927,19 @@ class DatabasePool: keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], ) -> int: + """ + Update rows in the given database table. + If the given keyvalues don't match anything, nothing will be updated. + + Args: + txn: The database transaction object. + table: The database table to update. + keyvalues: A mapping of column name to value to match rows on. + updatevalues: A mapping of column name to value to replace in any matched rows. + + Returns: + The number of rows that were updated. Will be 0 if no matching rows were found. + """ if keyvalues: where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) else: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 07908c41d9..86032897f5 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -123,7 +123,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_account_data_for_user( self, user_id: str ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - """Get all the client account_data for a user. + """ + Get all the client account_data for a user. + + If experimental MSC3391 support is enabled, any entries with an empty + content body are excluded; as this means they have been deleted. Args: user_id: The user to get the account_data for. @@ -135,27 +139,48 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_account_data_for_user_txn( txn: LoggingTransaction, ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - rows = self.db_pool.simple_select_list_txn( - txn, - "account_data", - {"user_id": user_id}, - ["account_data_type", "content"], - ) + # The 'content != '{}' condition below prevents us from using + # `simple_select_list_txn` here, as it doesn't support conditions + # other than 'equals'. + sql = """ + SELECT account_data_type, content FROM account_data + WHERE user_id = ? + """ + + # If experimental MSC3391 support is enabled, then account data entries + # with an empty content are considered "deleted". So skip adding them to + # the results. + if self.hs.config.experimental.msc3391_enabled: + sql += " AND content != '{}'" + + txn.execute(sql, (user_id,)) + rows = self.db_pool.cursor_to_dict(txn) global_account_data = { row["account_data_type"]: db_to_json(row["content"]) for row in rows } - rows = self.db_pool.simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id}, - ["room_id", "account_data_type", "content"], - ) + # The 'content != '{}' condition below prevents us from using + # `simple_select_list_txn` here, as it doesn't support conditions + # other than 'equals'. + sql = """ + SELECT room_id, account_data_type, content FROM room_account_data + WHERE user_id = ? + """ + + # If experimental MSC3391 support is enabled, then account data entries + # with an empty content are considered "deleted". So skip adding them to + # the results. + if self.hs.config.experimental.msc3391_enabled: + sql += " AND content != '{}'" + + txn.execute(sql, (user_id,)) + rows = self.db_pool.cursor_to_dict(txn) by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) + room_data[row["account_data_type"]] = db_to_json(row["content"]) return global_account_data, by_room @@ -411,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) + if stream_name == AccountDataStream.NAME: for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( @@ -429,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: @@ -469,6 +500,72 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return self._account_data_id_gen.get_current_token() + async def remove_account_data_for_room( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[int]: + """Delete the room account data for the user of a given type. + + Args: + user_id: The user to remove account_data for. + room_id: The room ID to scope the request to. + account_data_type: The account data type to delete. + + Returns: + The maximum stream position, or None if there was no matching room account + data to delete. + """ + assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) + + def _remove_account_data_for_room_txn( + txn: LoggingTransaction, next_id: int + ) -> bool: + """ + Args: + txn: The transaction object. + next_id: The stream_id to update any existing rows to. + + Returns: + True if an entry in room_account_data had its content set to '{}', + otherwise False. This informs callers of whether there actually was an + existing room account data entry to delete, or if the call was a no-op. + """ + # We can't use `simple_update` as it doesn't have the ability to specify + # where clauses other than '=', which we need for `content != '{}'` below. + sql = """ + UPDATE room_account_data + SET stream_id = ?, content = '{}' + WHERE user_id = ? + AND room_id = ? + AND account_data_type = ? + AND content != '{}' + """ + txn.execute( + sql, + (next_id, user_id, room_id, account_data_type), + ) + # Return true if any rows were updated. + return txn.rowcount != 0 + + async with self._account_data_id_gen.get_next() as next_id: + row_updated = await self.db_pool.runInteraction( + "remove_account_data_for_room", + _remove_account_data_for_room_txn, + next_id, + ) + + if not row_updated: + return None + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type), {} + ) + + return self._account_data_id_gen.get_current_token() + async def add_account_data_for_user( self, user_id: str, account_data_type: str, content: JsonDict ) -> int: @@ -569,6 +666,108 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) + async def remove_account_data_for_user( + self, + user_id: str, + account_data_type: str, + ) -> Optional[int]: + """ + Delete a single piece of user account data by type. + + A "delete" is performed by updating a potentially existing row in the + "account_data" database table for (user_id, account_data_type) and + setting its content to "{}". + + Args: + user_id: The user ID to modify the account data of. + account_data_type: The type to remove. + + Returns: + The maximum stream position, or None if there was no matching account data + to delete. + """ + assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) + + def _remove_account_data_for_user_txn( + txn: LoggingTransaction, next_id: int + ) -> bool: + """ + Args: + txn: The transaction object. + next_id: The stream_id to update any existing rows to. + + Returns: + True if an entry in account_data had its content set to '{}', otherwise + False. This informs callers of whether there actually was an existing + account data entry to delete, or if the call was a no-op. + """ + # We can't use `simple_update` as it doesn't have the ability to specify + # where clauses other than '=', which we need for `content != '{}'` below. + sql = """ + UPDATE account_data + SET stream_id = ?, content = '{}' + WHERE user_id = ? + AND account_data_type = ? + AND content != '{}' + """ + txn.execute(sql, (next_id, user_id, account_data_type)) + if txn.rowcount == 0: + # We didn't update any rows. This means that there was no matching room + # account data entry to delete in the first place. + return False + + # Ignored users get denormalized into a separate table as an optimisation. + if account_data_type == AccountDataTypes.IGNORED_USER_LIST: + # If this method was called with the ignored users account data type, we + # simply delete all ignored users. + + # First pull all the users that this user ignores. + previously_ignored_users = set( + self.db_pool.simple_select_onecol_txn( + txn, + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + ) + ) + + # Then delete them from the database. + self.db_pool.simple_delete_txn( + txn, + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + ) + + # Invalidate the cache for ignored users which were removed. + for ignored_user_id in previously_ignored_users: + self._invalidate_cache_and_stream( + txn, self.ignored_by, (ignored_user_id,) + ) + + # Invalidate for this user the cache tracking ignored users. + self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) + + return True + + async with self._account_data_id_gen.get_next() as next_id: + row_updated = await self.db_pool.runInteraction( + "remove_account_data_for_user", + _remove_account_data_for_user_txn, + next_id, + ) + + if not row_updated: + return None + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.prefill( + (user_id, account_data_type), {} + ) + + return self._account_data_id_gen.get_current_token() + async def purge_account_data_for_user(self, user_id: str) -> None: """ Removes ALL the account data for a user. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index a58668a380..2179a8bf59 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=True, ) elif stream_name == CachesStream.NAME: - if self._cache_id_gen: - self._cache_id_gen.advance(instance_name, token) - for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.keys is None: @@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 50899b2949..2440ac03f7 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a921332cb0..b067664473 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f80b494edb..90aa4e01bf 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore): token: int, rows: Iterable[Any], ) -> None: - if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(instance_name, token) - elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(instance_name, -token) - elif stream_name == UnPartialStatedEventStream.NAME: + if stream_name == UnPartialStatedEventStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) @@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(instance_name, token) + elif stream_name == BackfillStream.NAME: + self._backfill_id_gen.advance(instance_name, -token) + super().process_replication_position(stream_name, instance_name, token) + async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. if the content of the event has been erased from the database due to a redaction. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 9769a18a9d..7b60815043 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) rows: Iterable[Any], ) -> None: if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super().process_replication_rows(stream_name, instance_name, token, rows) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d4c64c46ad..d4e4b777da 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -154,6 +154,13 @@ class PushRulesWorkerStore( self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: rows = await self.db_pool.simple_select_list( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 40fd781a6a..7f24a3b6ec 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore): def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + def process_replication_position( + self, stream_name: str, instance_name: str, token: int ) -> None: if stream_name == PushersStream.NAME: self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + super().process_replication_position(stream_name, instance_name, token) async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e06725f69c..86f5bce5f0 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b0f5de67a3..e23c927e02 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py
@@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore): rows: Iterable[Any], ) -> None: if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + class TagsStore(TagsWorkerStore): pass diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index 5df03d3ddc..644c341e8c 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py
@@ -110,6 +110,9 @@ class OidcSessionData: ui_auth_session_id: str """The session ID of the ongoing UI Auth ("" if this is a login)""" + code_verifier: str + """The random string used in the RFC7636 code challenge ("" if PKCE is not being used).""" + class MacaroonGenerator: def __init__(self, clock: Clock, location: str, secret_key: bytes): @@ -187,6 +190,7 @@ class MacaroonGenerator: macaroon.add_first_party_caveat( f"ui_auth_session_id = {session_data.ui_auth_session_id}" ) + macaroon.add_first_party_caveat(f"code_verifier = {session_data.code_verifier}") macaroon.add_first_party_caveat(f"time < {expiry}") return macaroon.serialize() @@ -278,6 +282,7 @@ class MacaroonGenerator: v.satisfy_general(lambda c: c.startswith("idp_id = ")) v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) + v.satisfy_general(lambda c: c.startswith("code_verifier = ")) satisfy_expiry(v, self._clock.time_msec) v.verify(macaroon, self._secret_key) @@ -287,11 +292,13 @@ class MacaroonGenerator: idp_id = get_value_from_macaroon(macaroon, "idp_id") client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url") ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id") + code_verifier = get_value_from_macaroon(macaroon, "code_verifier") return OidcSessionData( nonce=nonce, idp_id=idp_id, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, + code_verifier=code_verifier, ) def _generate_base_macaroon(self, type: MacaroonType) -> pymacaroons.Macaroon: