diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 26834a437e..543bba27c2 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -166,22 +166,6 @@ IGNORED_TABLES = {
"ui_auth_sessions",
"ui_auth_sessions_credentials",
"ui_auth_sessions_ips",
- # Groups/communities is no longer supported.
- "group_attestations_remote",
- "group_attestations_renewals",
- "group_invites",
- "group_roles",
- "group_room_categories",
- "group_rooms",
- "group_summary_roles",
- "group_summary_room_categories",
- "group_summary_rooms",
- "group_summary_users",
- "group_users",
- "groups",
- "local_group_membership",
- "local_group_updates",
- "remote_profile_cache",
}
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 54d13026c9..f43965c1c8 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -27,6 +27,33 @@ class Ratelimiter:
"""
Ratelimit actions marked by arbitrary keys.
+ (Note that the source code speaks of "actions" and "burst_count" rather than
+ "tokens" and a "bucket_size".)
+
+ This is a "leaky bucket as a meter". For each key to be tracked there is a bucket
+ containing some number 0 <= T <= `burst_count` of tokens corresponding to previously
+ permitted requests for that key. Each bucket starts empty, and gradually leaks
+ tokens at a rate of `rate_hz`.
+
+ Upon an incoming request, we must determine:
+ - the key that this request falls under (which bucket to inspect), and
+ - the cost C of this request in tokens.
+ Then, if there is room in the bucket for C tokens (T + C <= `burst_count`),
+ the request is permitted and C tokens are added to the bucket.
+ Otherwise the request is denied, and the bucket continues to hold T tokens.
+
+ This means that the limiter enforces an average request frequency of `rate_hz`,
+ while accumulating a buffer of up to `burst_count` requests which can be consumed
+ instantaneously.
+
+ The tricky bit is the leaking. We do not want to have a periodic process which
+ leaks every bucket! Instead, we track
+ - the time point when the bucket was last completely empty, and
+ - how many tokens have been added to the bucket (i.e. requests permitted) since then.
+ Then for each incoming request, we can calculate how many tokens have leaked
+ since this time point, and use that to decide if we should accept or reject the
+ request.
+
Args:
clock: A homeserver clock, for retrieving the current time
rate_hz: The long term number of actions that can be performed in a second.
@@ -41,14 +68,30 @@ class Ratelimiter:
self.burst_count = burst_count
self.store = store
- # A ordered dictionary keeping track of actions, when they were last
- # performed and how often. Each entry is a mapping from a key of arbitrary type
- # to a tuple representing:
- # * How many times an action has occurred since a point in time
- # * The point in time
- # * The rate_hz of this particular entry. This can vary per request
+ # An ordered dictionary representing the token buckets tracked by this rate
+ # limiter. Each entry maps a key of arbitrary type to a tuple representing:
+ # * The number of tokens currently in the bucket,
+ # * The time point when the bucket was last completely empty, and
+ # * The rate_hz (leak rate) of this particular bucket.
self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
+ def _get_key(
+ self, requester: Optional[Requester], key: Optional[Hashable]
+ ) -> Hashable:
+ """Use the requester's MXID as a fallback key if no key is provided."""
+ if key is None:
+ if not requester:
+ raise ValueError("Must supply at least one of `requester` or `key`")
+
+ key = requester.user.to_string()
+ return key
+
+ def _get_action_counts(
+ self, key: Hashable, time_now_s: float
+ ) -> Tuple[float, float, float]:
+ """Retrieve the action counts, with a fallback representing an empty bucket."""
+ return self.actions.get(key, (0.0, time_now_s, 0.0))
+
async def can_do_action(
self,
requester: Optional[Requester],
@@ -88,11 +131,7 @@ class Ratelimiter:
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
- if key is None:
- if not requester:
- raise ValueError("Must supply at least one of `requester` or `key`")
-
- key = requester.user.to_string()
+ key = self._get_key(requester, key)
if requester:
# Disable rate limiting of users belonging to any AS that is configured
@@ -121,7 +160,7 @@ class Ratelimiter:
self._prune_message_counts(time_now_s)
# Check if there is an existing count entry for this key
- action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+ action_count, time_start, _ = self._get_action_counts(key, time_now_s)
# Check whether performing another action is allowed
time_delta = time_now_s - time_start
@@ -164,6 +203,37 @@ class Ratelimiter:
return allowed, time_allowed
+ def record_action(
+ self,
+ requester: Optional[Requester],
+ key: Optional[Hashable] = None,
+ n_actions: int = 1,
+ _time_now_s: Optional[float] = None,
+ ) -> None:
+ """Record that one or more actions took place, even if they violate the rate limit.
+
+ This is useful for tracking the frequency of events that happen across
+ federation which we still want to impose local rate limits on. For instance, if
+ we are alice.com monitoring a particular room, we cannot prevent bob.com
+ from joining users to that room. However, we can track the number of recent
+ joins in the room and refuse to serve new joins ourselves if there have been too
+ many in the room across both homeservers.
+
+ Args:
+ requester: The requester that is doing the action, if any.
+ key: An arbitrary key used to classify an action. Defaults to the
+ requester's user ID.
+ n_actions: The number of times the user performed the action. These are
+ recorded unconditionally — the bucket is incremented even if doing so
+ takes it over the rate limit.
+ _time_now_s: The current time. Optional, defaults to the current time according
+ to self.clock. Only used by tests.
+ """
+ key = self._get_key(requester, key)
+ time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
+ action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
+ self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+
def _prune_message_counts(self, time_now_s: float) -> None:
"""Remove message count entries that have not exceeded their defined
rate_hz limit
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 3f85d61b46..00e81b3afc 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -84,6 +84,8 @@ class RoomVersion:
# MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of
# knocks and restricted join rules into the same join condition.
msc3787_knock_restricted_join_rule: bool
+ # MSC3667: Enforce integer power levels
+ msc3667_int_only_power_levels: bool
class RoomVersions:
@@ -103,6 +105,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
V2 = RoomVersion(
"2",
@@ -120,6 +123,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
V3 = RoomVersion(
"3",
@@ -137,6 +141,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
V4 = RoomVersion(
"4",
@@ -154,6 +159,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
V5 = RoomVersion(
"5",
@@ -171,6 +177,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
V6 = RoomVersion(
"6",
@@ -188,6 +195,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@@ -205,6 +213,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
V7 = RoomVersion(
"7",
@@ -222,6 +231,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
V8 = RoomVersion(
"8",
@@ -239,6 +249,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
V9 = RoomVersion(
"9",
@@ -256,6 +267,7 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
MSC2716v3 = RoomVersion(
"org.matrix.msc2716v3",
@@ -273,6 +285,7 @@ class RoomVersions:
msc2716_historical=True,
msc2716_redactions=True,
msc3787_knock_restricted_join_rule=False,
+ msc3667_int_only_power_levels=False,
)
MSC3787 = RoomVersion(
"org.matrix.msc3787",
@@ -290,6 +303,25 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=True,
+ msc3667_int_only_power_levels=False,
+ )
+ V10 = RoomVersion(
+ "10",
+ RoomDisposition.STABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=True,
+ msc3375_redaction_rules=True,
+ msc2403_knocking=True,
+ msc2716_historical=False,
+ msc2716_redactions=False,
+ msc3787_knock_restricted_join_rule=True,
+ msc3667_int_only_power_levels=True,
)
@@ -308,6 +340,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V9,
RoomVersions.MSC2716v3,
RoomVersions.MSC3787,
+ RoomVersions.V10,
)
}
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 745e704141..6bafa7d3f3 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -44,7 +44,6 @@ from synapse.app._base import (
register_start,
)
from synapse.config._base import ConfigError, format_config_error
-from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import ListenerConfig
from synapse.federation.transport.server import TransportLayerServer
@@ -202,7 +201,7 @@ class SynapseHomeServer(HomeServer):
}
)
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.can_verify_email:
from synapse.rest.synapse.client.password_reset import (
PasswordResetSubmitTokenResource,
)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index df1c214462..0963fb3bb4 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -53,6 +53,18 @@ sent_events_counter = Counter(
"synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
)
+sent_ephemeral_counter = Counter(
+ "synapse_appservice_api_sent_ephemeral",
+ "Number of ephemeral events sent to the AS",
+ ["service"],
+)
+
+sent_todevice_counter = Counter(
+ "synapse_appservice_api_sent_todevice",
+ "Number of todevice messages sent to the AS",
+ ["service"],
+)
+
HOUR_IN_MS = 60 * 60 * 1000
@@ -310,6 +322,8 @@ class ApplicationServiceApi(SimpleHttpClient):
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(serialized_events))
+ sent_ephemeral_counter.labels(service.id).inc(len(ephemeral))
+ sent_todevice_counter.labels(service.id).inc(len(to_device_messages))
return True
except CodeMessageException as e:
logger.warning(
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 6e11fbdb9a..3ead80d985 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -18,7 +18,6 @@
import email.utils
import logging
import os
-from enum import Enum
from typing import Any
import attr
@@ -131,41 +130,22 @@ class EmailConfig(Config):
self.email_enable_notifs = email_config.get("enable_notifs", False)
- self.threepid_behaviour_email = (
- # Have Synapse handle the email sending if account_threepid_delegates.email
- # is not defined
- # msisdn is currently always remote while Synapse does not support any method of
- # sending SMS messages
- ThreepidBehaviour.REMOTE
- if self.root.registration.account_threepid_delegate_email
- else ThreepidBehaviour.LOCAL
- )
-
if config.get("trust_identity_server_for_password_resets"):
raise ConfigError(
'The config option "trust_identity_server_for_password_resets" '
- 'has been replaced by "account_threepid_delegate". '
- "Please consult the configuration manual at docs/usage/configuration/config_documentation.md for "
- "details and update your config file."
+ "is no longer supported. Please remove it from the config file."
)
- self.local_threepid_handling_disabled_due_to_email_config = False
- if (
- self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
- and email_config == {}
- ):
- # We cannot warn the user this has happened here
- # Instead do so when a user attempts to reset their password
- self.local_threepid_handling_disabled_due_to_email_config = True
-
- self.threepid_behaviour_email = ThreepidBehaviour.OFF
+ # If we have email config settings, assume that we can verify ownership of
+ # email addresses.
+ self.can_verify_email = email_config != {}
# Get lifetime of a validation token in milliseconds
self.email_validation_token_lifetime = self.parse_duration(
email_config.get("validation_token_lifetime", "1h")
)
- if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.can_verify_email:
missing = []
if not self.email_notif_from:
missing.append("email.notif_from")
@@ -356,18 +336,3 @@ class EmailConfig(Config):
"Config option email.invite_client_location must be a http or https URL",
path=("email", "invite_client_location"),
)
-
-
-class ThreepidBehaviour(Enum):
- """
- Enum to define the behaviour of Synapse with regards to when it contacts an identity
- server for 3pid registration and password resets
-
- REMOTE = use an external server to send tokens
- LOCAL = send tokens ourselves
- OFF = disable registration via 3pid and password resets
- """
-
- REMOTE = "remote"
- LOCAL = "local"
- OFF = "off"
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4fc1784efe..5a91917b4a 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -112,6 +112,13 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.01, "burst_count": 10},
)
+ # Track the rate of joins to a given room. If there are too many, temporarily
+ # prevent local joins and remote joins via this server.
+ self.rc_joins_per_room = RateLimitConfig(
+ config.get("rc_joins_per_room", {}),
+ defaults={"per_second": 1, "burst_count": 10},
+ )
+
# Ratelimit cross-user key requests:
# * For local requests this is keyed by the sending device.
# * For requests received over federation this is keyed by the origin.
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index fcf99be092..685a0423c5 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -20,6 +20,13 @@ from synapse.config._base import Config, ConfigError
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string_with_symbols, strtobool
+NO_EMAIL_DELEGATE_ERROR = """\
+Delegation of email verification to an identity server is no longer supported. To
+continue to allow users to add email addresses to their accounts, and use them for
+password resets, configure Synapse with an SMTP server via the `email` setting, and
+remove `account_threepid_delegates.email`.
+"""
+
class RegistrationConfig(Config):
section = "registration"
@@ -51,7 +58,9 @@ class RegistrationConfig(Config):
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
account_threepid_delegates = config.get("account_threepid_delegates") or {}
- self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+ if "email" in account_threepid_delegates:
+ raise ConfigError(NO_EMAIL_DELEGATE_ERROR)
+ # Note: account_threepid_delegates.email is rejected above and no longer stored.
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 3c69dd325f..1033496bb4 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -42,6 +42,18 @@ THUMBNAIL_SIZE_YAML = """\
# method: %(method)s
"""
+# A map from the given media type to the type of thumbnail we should generate
+# for it.
+THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP = {
+ "image/jpeg": "jpeg",
+ "image/jpg": "jpeg",
+ "image/webp": "jpeg",
+ # Thumbnails can only be jpeg or png. We choose png thumbnails for gifs
+ # because gifs can have transparency and png preserves it.
+ "image/gif": "png",
+ "image/png": "png",
+}
+
HTTP_PROXY_SET_WARNING = """\
The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
@@ -79,13 +91,22 @@ def parse_thumbnail_requirements(
width = size["width"]
height = size["height"]
method = size["method"]
- jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
- png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
- requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
- requirements.setdefault("image/jpg", []).append(jpeg_thumbnail)
- requirements.setdefault("image/webp", []).append(jpeg_thumbnail)
- requirements.setdefault("image/gif", []).append(png_thumbnail)
- requirements.setdefault("image/png", []).append(png_thumbnail)
+
+ for format, thumbnail_format in THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.items():
+ requirement = requirements.setdefault(format, [])
+ if thumbnail_format == "jpeg":
+ requirement.append(
+ ThumbnailRequirement(width, height, method, "image/jpeg")
+ )
+ elif thumbnail_format == "png":
+ requirement.append(
+ ThumbnailRequirement(width, height, method, "image/png")
+ )
+ else:
+ raise Exception(
+ "Unknown thumbnail mapping from %s to %s. This is a Synapse problem, please report!"
+ % (format, thumbnail_format)
+ )
return {
media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
}
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 0fc2c4b27e..965cb265da 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -740,6 +740,32 @@ def _check_power_levels(
except Exception:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
+ # Reject events with stringy power levels if required by room version
+ if (
+ event.type == EventTypes.PowerLevels
+ and room_version_obj.msc3667_int_only_power_levels
+ ):
+ for k, v in event.content.items():
+ if k in {
+ "users_default",
+ "events_default",
+ "state_default",
+ "ban",
+ "redact",
+ "kick",
+ "invite",
+ }:
+ if not isinstance(v, int):
+ raise SynapseError(400, f"{v!r} must be an integer.")
+ if k in {"events", "notifications", "users"}:
+ if not isinstance(v, dict) or not all(
+ isinstance(v, int) for v in v.values()
+ ):
+ raise SynapseError(
+ 400,
+ f"{v!r} must be a dict wherein all the values are integers.",
+ )
+
key = (event.type, event.state_key)
current_state = auth_events.get(key)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 98c203ada0..17f624b68f 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -24,9 +24,11 @@ from synapse.api.room_versions import (
RoomVersion,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
+from synapse.storage.state import StateFilter
from synapse.types import EventID, JsonDict
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -120,8 +122,12 @@ class EventBuilder:
The signed and hashed event.
"""
if auth_event_ids is None:
- state_ids = await self._state.get_current_state_ids(
- self.room_id, prev_event_ids
+ state_ids = await self._state.compute_state_after_events(
+ self.room_id,
+ prev_event_ids,
+ state_filter=StateFilter.from_types(
+ auth_types_for_event(self.room_version, self)
+ ),
)
auth_event_ids = self._event_auth_handler.compute_auth_events(
self, state_ids
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5dfdc86740..ae550d3f4d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -118,6 +118,7 @@ class FederationServer(FederationBase):
self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
+ self._room_member_handler = hs.get_room_member_handler()
self._state_storage_controller = hs.get_storage_controllers().state
@@ -621,6 +622,15 @@ class FederationServer(FederationBase):
)
raise IncompatibleRoomVersionError(room_version=room_version)
+ # Refuse the request if that room has seen too many joins recently.
+ # This is in addition to the HS-level rate limiting applied by
+ # BaseFederationServlet.
+ # type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
+ await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
+ requester=None,
+ key=room_id,
+ update=False,
+ )
pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
@@ -655,6 +665,12 @@ class FederationServer(FederationBase):
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]:
+ await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
+ requester=None,
+ key=room_id,
+ update=False,
+ )
+
event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 99a794c042..94a65ac65f 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -351,7 +351,11 @@ class FederationSender(AbstractFederationSender):
self._is_processing = True
while True:
last_token = await self.store.get_federation_out_pos("events")
- next_token, events = await self.store.get_all_new_events_stream(
+ (
+ next_token,
+ events,
+ event_to_received_ts,
+ ) = await self.store.get_all_new_events_stream(
last_token, self._last_poked_id, limit=100
)
@@ -476,7 +480,7 @@ class FederationSender(AbstractFederationSender):
await self._send_pdu(event, sharded_destinations)
now = self.clock.time_msec()
- ts = await self.store.get_received_ts(event.event_id)
+ ts = event_to_received_ts[event.event_id]
assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels(
"federation_sender"
@@ -509,7 +513,7 @@ class FederationSender(AbstractFederationSender):
if events:
now = self.clock.time_msec()
- ts = await self.store.get_received_ts(events[-1].event_id)
+ ts = event_to_received_ts[events[-1].event_id]
assert ts is not None
synapse.metrics.event_processing_lag.labels(
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 814553e098..203b62e015 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -104,14 +104,15 @@ class ApplicationServicesHandler:
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
- limit = 100
upper_bound = -1
while upper_bound < self.current_max:
+ last_token = await self.store.get_appservice_last_pos()
(
upper_bound,
events,
- ) = await self.store.get_new_events_for_appservice(
- self.current_max, limit
+ event_to_received_ts,
+ ) = await self.store.get_all_new_events_stream(
+ last_token, self.current_max, limit=100, get_prev_content=True
)
events_by_room: Dict[str, List[EventBase]] = {}
@@ -150,7 +151,7 @@ class ApplicationServicesHandler:
)
now = self.clock.time_msec()
- ts = await self.store.get_received_ts(event.event_id)
+ ts = event_to_received_ts[event.event_id]
assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels(
@@ -187,7 +188,7 @@ class ApplicationServicesHandler:
if events:
now = self.clock.time_msec()
- ts = await self.store.get_received_ts(events[-1].event_id)
+ ts = event_to_received_ts[events[-1].event_id]
assert ts is not None
synapse.metrics.event_processing_lag.labels(
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index c74117c19a..766d9849f5 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
import itertools
import logging
from http import HTTPStatus
@@ -347,7 +348,7 @@ class FederationEventHandler:
event.internal_metadata.send_on_behalf_of = origin
context = await self._state_handler.compute_event_context(event)
- context = await self._check_event_auth(origin, event, context)
+ await self._check_event_auth(origin, event, context)
if context.rejected:
raise SynapseError(
403, f"{event.membership} event was rejected", Codes.FORBIDDEN
@@ -485,7 +486,7 @@ class FederationEventHandler:
partial_state=partial_state,
)
- context = await self._check_event_auth(origin, event, context)
+ await self._check_event_auth(origin, event, context)
if context.rejected:
raise SynapseError(400, "Join event was rejected")
@@ -1116,11 +1117,7 @@ class FederationEventHandler:
state_ids_before_event=state_ids,
)
try:
- context = await self._check_event_auth(
- origin,
- event,
- context,
- )
+ await self._check_event_auth(origin, event, context)
except AuthError as e:
# This happens only if we couldn't find the auth events. We'll already have
# logged a warning, so now we just convert to a FederationError.
@@ -1495,11 +1492,8 @@ class FederationEventHandler:
)
async def _check_event_auth(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- ) -> EventContext:
+ self, origin: str, event: EventBase, context: EventContext
+ ) -> None:
"""
Checks whether an event should be rejected (for failing auth checks).
@@ -1509,9 +1503,6 @@ class FederationEventHandler:
context:
The event context.
- Returns:
- The updated context object.
-
Raises:
AuthError if we were unable to find copies of the event's auth events.
(Most other failures just cause us to set `context.rejected`.)
@@ -1526,7 +1517,7 @@ class FederationEventHandler:
logger.warning("While validating received event %r: %s", event, e)
# TODO: use a different rejected reason here?
context.rejected = RejectedReason.AUTH_ERROR
- return context
+ return
# next, check that we have all of the event's auth events.
#
@@ -1538,6 +1529,9 @@ class FederationEventHandler:
)
# ... and check that the event passes auth at those auth events.
+ # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+ # 4. Passes authorization rules based on the event’s auth events,
+ # otherwise it is rejected.
try:
await check_state_independent_auth_rules(self._store, event)
check_state_dependent_auth_rules(event, claimed_auth_events)
@@ -1546,55 +1540,90 @@ class FederationEventHandler:
"While checking auth of %r against auth_events: %s", event, e
)
context.rejected = RejectedReason.AUTH_ERROR
- return context
+ return
+
+ # now check the auth rules pass against the room state before the event
+ # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+ # 5. Passes authorization rules based on the state before the event,
+ # otherwise it is rejected.
+ #
+ # ... however, if we only have partial state for the room, then there is a good
+ # chance that we'll be missing some of the state needed to auth the new event.
+ # So, we state-resolve the auth events that we are given against the state that
+ # we know about, which ensures things like bans are applied. (Note that we'll
+ # already have checked we have all the auth events, in
+ # _load_or_fetch_auth_events_for_event above)
+ if context.partial_state:
+ room_version = await self._store.get_room_version_id(event.room_id)
+
+ local_state_id_map = await context.get_prev_state_ids()
+ claimed_auth_events_id_map = {
+ (ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
+ }
+
+ state_for_auth_id_map = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ event.room_id,
+ room_version,
+ [local_state_id_map, claimed_auth_events_id_map],
+ event_map=None,
+ state_res_store=StateResolutionStore(self._store),
+ )
+ )
+ else:
+ event_types = event_auth.auth_types_for_event(event.room_version, event)
+ state_for_auth_id_map = await context.get_prev_state_ids(
+ StateFilter.from_types(event_types)
+ )
- # now check auth against what we think the auth events *should* be.
- event_types = event_auth.auth_types_for_event(event.room_version, event)
- prev_state_ids = await context.get_prev_state_ids(
- StateFilter.from_types(event_types)
+ calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
+ event, state_for_auth_id_map, for_verification=True
)
- auth_events_ids = self._event_auth_handler.compute_auth_events(
- event, prev_state_ids, for_verification=True
+ # if those are the same, we're done here.
+ if collections.Counter(event.auth_event_ids()) == collections.Counter(
+ calculated_auth_event_ids
+ ):
+ return
+
+ # otherwise, re-run the auth checks based on what we calculated.
+ calculated_auth_events = await self._store.get_events_as_list(
+ calculated_auth_event_ids
)
- auth_events_x = await self._store.get_events(auth_events_ids)
+
+ # log the differences
+
+ claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
calculated_auth_event_map = {
- (e.type, e.state_key): e for e in auth_events_x.values()
+ (e.type, e.state_key): e for e in calculated_auth_events
}
+ logger.info(
+ "event's auth_events are different to our calculated auth_events. "
+ "Claimed but not calculated: %s. Calculated but not claimed: %s",
+ [
+ ev
+ for k, ev in claimed_auth_event_map.items()
+ if k not in calculated_auth_event_map
+ or calculated_auth_event_map[k].event_id != ev.event_id
+ ],
+ [
+ ev
+ for k, ev in calculated_auth_event_map.items()
+ if k not in claimed_auth_event_map
+ or claimed_auth_event_map[k].event_id != ev.event_id
+ ],
+ )
try:
- updated_auth_events = await self._update_auth_events_for_auth(
+ check_state_dependent_auth_rules(event, calculated_auth_events)
+ except AuthError as e:
+ logger.warning(
+ "While checking auth of %r against room state before the event: %s",
event,
- calculated_auth_event_map=calculated_auth_event_map,
- )
- except Exception:
- # We don't really mind if the above fails, so lets not fail
- # processing if it does. However, it really shouldn't fail so
- # let's still log as an exception since we'll still want to fix
- # any bugs.
- logger.exception(
- "Failed to double check auth events for %s with remote. "
- "Ignoring failure and continuing processing of event.",
- event.event_id,
- )
- updated_auth_events = None
-
- if updated_auth_events:
- context = await self._update_context_for_auth_events(
- event, context, updated_auth_events
+ e,
)
- auth_events_for_auth = updated_auth_events
- else:
- auth_events_for_auth = calculated_auth_event_map
-
- try:
- check_state_dependent_auth_rules(event, auth_events_for_auth.values())
- except AuthError as e:
- logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR
- return context
-
async def _maybe_kick_guest_users(self, event: EventBase) -> None:
if event.type != EventTypes.GuestAccess:
return
@@ -1704,93 +1733,6 @@ class FederationEventHandler:
soft_failed_event_counter.inc()
event.internal_metadata.soft_failed = True
- async def _update_auth_events_for_auth(
- self,
- event: EventBase,
- calculated_auth_event_map: StateMap[EventBase],
- ) -> Optional[StateMap[EventBase]]:
- """Helper for _check_event_auth. See there for docs.
-
- Checks whether a given event has the expected auth events. If it
- doesn't then we talk to the remote server to compare state to see if
- we can come to a consensus (e.g. if one server missed some valid
- state).
-
- This attempts to resolve any potential divergence of state between
- servers, but is not essential and so failures should not block further
- processing of the event.
-
- Args:
- event:
-
- calculated_auth_event_map:
- Our calculated auth_events based on the state of the room
- at the event's position in the DAG.
-
- Returns:
- updated auth event map, or None if no changes are needed.
-
- """
- assert not event.internal_metadata.outlier
-
- # check for events which are in the event's claimed auth_events, but not
- # in our calculated event map.
- event_auth_events = set(event.auth_event_ids())
- different_auth = event_auth_events.difference(
- e.event_id for e in calculated_auth_event_map.values()
- )
-
- if not different_auth:
- return None
-
- logger.info(
- "auth_events refers to events which are not in our calculated auth "
- "chain: %s",
- different_auth,
- )
-
- # XXX: currently this checks for redactions but I'm not convinced that is
- # necessary?
- different_events = await self._store.get_events_as_list(different_auth)
-
- # double-check they're all in the same room - we should already have checked
- # this but it doesn't hurt to check again.
- for d in different_events:
- assert (
- d.room_id == event.room_id
- ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
-
- # now we state-resolve between our own idea of the auth events, and the remote's
- # idea of them.
-
- local_state = calculated_auth_event_map.values()
- remote_auth_events = dict(calculated_auth_event_map)
- remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
- remote_state = remote_auth_events.values()
-
- room_version = await self._store.get_room_version_id(event.room_id)
- new_state = await self._state_handler.resolve_events(
- room_version, (local_state, remote_state), event
- )
- different_state = {
- (d.type, d.state_key): d
- for d in new_state.values()
- if calculated_auth_event_map.get((d.type, d.state_key)) != d
- }
- if not different_state:
- logger.info("State res returned no new state")
- return None
-
- logger.info(
- "After state res: updating auth_events with new state %s",
- different_state.values(),
- )
-
- # take a copy of calculated_auth_event_map before we modify it.
- auth_events = dict(calculated_auth_event_map)
- auth_events.update(different_state)
- return auth_events
-
async def _load_or_fetch_auth_events_for_event(
self, destination: str, event: EventBase
) -> Collection[EventBase]:
@@ -1888,61 +1830,6 @@ class FederationEventHandler:
await self._auth_and_persist_outliers(room_id, remote_auth_events)
- async def _update_context_for_auth_events(
- self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
- ) -> EventContext:
- """Update the state_ids in an event context after auth event resolution,
- storing the changes as a new state group.
-
- Args:
- event: The event we're handling the context for
-
- context: initial event context
-
- auth_events: Events to update in the event context.
-
- Returns:
- new event context
- """
- # exclude the state key of the new event from the current_state in the context.
- if event.is_state():
- event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
- else:
- event_key = None
- state_updates = {
- k: a.event_id for k, a in auth_events.items() if k != event_key
- }
-
- current_state_ids = await context.get_current_state_ids()
- current_state_ids = dict(current_state_ids) # type: ignore
-
- current_state_ids.update(state_updates)
-
- prev_state_ids = await context.get_prev_state_ids()
- prev_state_ids = dict(prev_state_ids)
-
- prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
- # create a new state group as a delta from the existing one.
- prev_group = context.state_group
- state_group = await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=prev_group,
- delta_ids=state_updates,
- current_state_ids=current_state_ids,
- )
-
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group=state_group,
- state_group_before_event=context.state_group_before_event,
- state_delta_due_to_event=state_updates,
- prev_group=prev_group,
- delta_ids=state_updates,
- partial_state=context.partial_state,
- )
-
async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> None:
@@ -2093,6 +1980,10 @@ class FederationEventHandler:
event, event_pos, max_stream_token, extra_users=extra_users
)
+ if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+ # TODO retrieve the previous state, and exclude join -> join transitions
+ self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9bca2bc4b2..9571d461c8 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -26,7 +26,6 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
-from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
@@ -163,8 +162,7 @@ class IdentityHandler:
sid: str,
mxid: str,
id_server: str,
- id_access_token: Optional[str] = None,
- use_v2: bool = True,
+ id_access_token: str,
) -> JsonDict:
"""Bind a 3PID to an identity server
@@ -174,8 +172,7 @@ class IdentityHandler:
mxid: The MXID to bind the 3PID to
id_server: The domain of the identity server to query
id_access_token: The access token to authenticate to the identity
- server with, if necessary. Required if use_v2 is true
- use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
+ server with
Raises:
SynapseError: On any of the following conditions
@@ -187,24 +184,15 @@ class IdentityHandler:
"""
logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
- # If an id_access_token is not supplied, force usage of v1
- if id_access_token is None:
- use_v2 = False
-
if not valid_id_server_location(id_server):
raise SynapseError(
400,
"id_server must be a valid hostname with optional port and path components",
)
- # Decide which API endpoint URLs to use
- headers = {}
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
- if use_v2:
- bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
- headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
- else:
- bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+ bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+ headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
# Use the blacklisting http client as this call is only to identity servers
@@ -223,21 +211,14 @@ class IdentityHandler:
return data
except HttpResponseException as e:
- if e.code != 404 or not use_v2:
- logger.error("3PID bind failed with Matrix error: %r", e)
- raise e.to_synapse_error()
+ logger.error("3PID bind failed with Matrix error: %r", e)
+ raise e.to_synapse_error()
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e:
data = json_decoder.decode(e.msg) # XXX WAT?
return data
- logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
- res = await self.bind_threepid(
- client_secret, sid, mxid, id_server, id_access_token, use_v2=False
- )
- return res
-
async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on
@@ -300,8 +281,8 @@ class IdentityHandler:
"id_server must be a valid hostname with optional port and path components",
)
- url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
- url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
+ url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,)
+ url_bytes = b"/_matrix/identity/v2/3pid/unbind"
content = {
"mxid": mxid,
@@ -434,48 +415,6 @@ class IdentityHandler:
return session_id
- async def requestEmailToken(
- self,
- id_server: str,
- email: str,
- client_secret: str,
- send_attempt: int,
- next_link: Optional[str] = None,
- ) -> JsonDict:
- """
- Request an external server send an email on our behalf for the purposes of threepid
- validation.
-
- Args:
- id_server: The identity server to proxy to
- email: The email to send the message to
- client_secret: The unique client_secret sends by the user
- send_attempt: Which attempt this is
- next_link: A link to redirect the user to once they submit the token
-
- Returns:
- The json response body from the server
- """
- params = {
- "email": email,
- "client_secret": client_secret,
- "send_attempt": send_attempt,
- }
- if next_link:
- params["next_link"] = next_link
-
- try:
- data = await self.http_client.post_json_get_json(
- id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
- params,
- )
- return data
- except HttpResponseException as e:
- logger.info("Proxied requestToken failed: %r", e)
- raise e.to_synapse_error()
- except RequestTimedOutError:
- raise SynapseError(500, "Timed out contacting identity server")
-
async def requestMsisdnToken(
self,
id_server: str,
@@ -549,18 +488,7 @@ class IdentityHandler:
validation_session = None
# Try to validate as email
- if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- # Remote emails will only be used if a valid identity server is provided.
- assert (
- self.hs.config.registration.account_threepid_delegate_email is not None
- )
-
- # Ask our delegated email identity server
- validation_session = await self.threepid_from_creds(
- self.hs.config.registration.account_threepid_delegate_email,
- threepid_creds,
- )
- elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.hs.config.email.can_verify_email:
# Get a validated session matching these details
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1980e37dae..bd7baef051 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -463,6 +463,7 @@ class EventCreationHandler:
)
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
+ self._notifier = hs.get_notifier()
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
@@ -1444,7 +1445,12 @@ class EventCreationHandler:
if state_entry.state_group in self._external_cache_joined_hosts_updates:
return
- joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+ state = await state_entry.get_state(
+ self._storage_controllers.state, StateFilter.all()
+ )
+ joined_hosts = await self.store.get_joined_hosts(
+ event.room_id, state, state_entry
+ )
# Note that the expiry times must be larger than the expiry time in
# _external_cache_joined_hosts_updates.
@@ -1545,6 +1551,16 @@ class EventCreationHandler:
requester, is_admin_redaction=is_admin_redaction
)
+ if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+ (
+ current_membership,
+ _,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ event.state_key, event.room_id
+ )
+ if current_membership != Membership.JOIN:
+ self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
await self._maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
@@ -1844,13 +1860,8 @@ class EventCreationHandler:
# For each room we need to find a joined member we can use to send
# the dummy event with.
- latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- members = await self.state.get_current_users_in_room(
- room_id, latest_event_ids=latest_event_ids
- )
+ members = await self.store.get_local_users_in_room(room_id)
for user_id in members:
- if not self.hs.is_mine_id(user_id):
- continue
requester = create_requester(user_id, authenticated_entity=self.server_name)
try:
event, context = await self.create_event(
@@ -1861,7 +1872,6 @@ class EventCreationHandler:
"room_id": room_id,
"sender": user_id,
},
- prev_event_ids=latest_event_ids,
)
event.internal_metadata.proactively_send = False
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a54f163c0a..978d3ee39f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -889,7 +889,11 @@ class RoomCreationHandler:
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
- last_stream_id = await self._send_events_for_new_room(
+ (
+ last_stream_id,
+ last_sent_event_id,
+ depth,
+ ) = await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@@ -905,7 +909,7 @@ class RoomCreationHandler:
if "name" in config:
name = config["name"]
(
- _,
+ name_event,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
@@ -917,12 +921,16 @@ class RoomCreationHandler:
"content": {"name": name},
},
ratelimit=False,
+ prev_event_ids=[last_sent_event_id],
+ depth=depth,
)
+ last_sent_event_id = name_event.event_id
+ depth += 1
if "topic" in config:
topic = config["topic"]
(
- _,
+ topic_event,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
@@ -934,7 +942,11 @@ class RoomCreationHandler:
"content": {"topic": topic},
},
ratelimit=False,
+ prev_event_ids=[last_sent_event_id],
+ depth=depth,
)
+ last_sent_event_id = topic_event.event_id
+ depth += 1
# we avoid dropping the lock between invites, as otherwise joins can
# start coming in and making the createRoom slow.
@@ -949,7 +961,7 @@ class RoomCreationHandler:
for invitee in invite_list:
(
- _,
+ member_event_id,
last_stream_id,
) = await self.room_member_handler.update_membership_locked(
requester,
@@ -959,7 +971,11 @@ class RoomCreationHandler:
ratelimit=False,
content=content,
new_room=True,
+ prev_event_ids=[last_sent_event_id],
+ depth=depth,
)
+ last_sent_event_id = member_event_id
+ depth += 1
for invite_3pid in invite_3pid_list:
id_server = invite_3pid["id_server"]
@@ -968,7 +984,10 @@ class RoomCreationHandler:
medium = invite_3pid["medium"]
# Note that do_3pid_invite can raise a ShadowBanError, but this was
# handled above by emptying invite_3pid_list.
- last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
+ (
+ member_event_id,
+ last_stream_id,
+ ) = await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@@ -977,7 +996,11 @@ class RoomCreationHandler:
requester,
txn_id=None,
id_access_token=id_access_token,
+ prev_event_ids=[last_sent_event_id],
+ depth=depth,
)
+ last_sent_event_id = member_event_id
+ depth += 1
result = {"room_id": room_id}
@@ -1005,20 +1028,22 @@ class RoomCreationHandler:
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
ratelimit: bool = True,
- ) -> int:
+ ) -> Tuple[int, str, int]:
"""Sends the initial events into a new room.
`power_level_content_override` doesn't apply when initial state has
power level state event content.
Returns:
- The stream_id of the last event persisted.
+ A tuple containing the stream ID, event ID and depth of the last
+ event sent to the room.
"""
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
+ depth = 1
last_sent_event_id: Optional[str] = None
def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
@@ -1031,6 +1056,7 @@ class RoomCreationHandler:
async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
nonlocal last_sent_event_id
+ nonlocal depth
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
@@ -1047,9 +1073,11 @@ class RoomCreationHandler:
# Note: we don't pass state_event_ids here because this triggers
# an additional query per event to look them up from the events table.
prev_event_ids=[last_sent_event_id] if last_sent_event_id else [],
+ depth=depth,
)
last_sent_event_id = sent_event.event_id
+ depth += 1
return last_stream_id
@@ -1075,6 +1103,7 @@ class RoomCreationHandler:
content=creator_join_profile,
new_room=True,
prev_event_ids=[last_sent_event_id],
+ depth=depth,
)
last_sent_event_id = member_event_id
@@ -1168,7 +1197,7 @@ class RoomCreationHandler:
content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
)
- return last_sent_stream_id
+ return last_sent_stream_id, last_sent_event_id, depth
def _generate_room_id(self) -> str:
"""Generates a random room ID.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 04c44b2ccb..30b4cb23df 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -94,12 +94,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
)
+ # Tracks joins from local users to rooms this server isn't a member of.
+ # I.e. joins this server makes by requesting /make_join /send_join from
+ # another server.
self._join_rate_limiter_remote = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
+ # TODO: find a better place to keep this Ratelimiter.
+ # It needs to be
+ # - written to by event persistence code
+ # - written to by something which can snoop on replication streams
+ # - read by the RoomMemberHandler to rate limit joins from local users
+ # - read by the FederationServer to rate limit make_joins and send_joins from
+ # other homeservers
+ # I wonder if a homeserver-wide collection of rate limiters might be cleaner?
+ self._join_rate_per_room_limiter = Ratelimiter(
+ store=self.store,
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
+ burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
+ )
# Ratelimiter for invites, keyed by room (across all issuers, all
# recipients).
@@ -136,6 +153,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
self.request_ratelimiter = hs.get_request_ratelimiter()
+ hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
+
+ def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
+ """Notify the rate limiter that a room join has occurred.
+
+ Use this to inform the RoomMemberHandler about joins that have either
+ - taken place on another homeserver, or
+ - on another worker in this homeserver.
+ Joins actioned by this worker should use the usual `ratelimit` method, which
+ checks the limit and increments the counter in one go.
+ """
+ self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
@abc.abstractmethod
async def _remote_join(
@@ -285,6 +314,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
@@ -315,6 +345,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
txn_id:
ratelimit:
@@ -370,6 +403,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
+ depth=depth,
require_consent=require_consent,
outlier=outlier,
historical=historical,
@@ -391,6 +425,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# up blocking profile updates.
if newly_joined and ratelimit:
await self._join_rate_limiter_local.ratelimit(requester)
+ await self._join_rate_per_room_limiter.ratelimit(
+ requester, key=room_id, update=False
+ )
result_event = await self.event_creation_handler.handle_new_client_event(
requester,
@@ -466,6 +503,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@@ -501,6 +539,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@@ -540,6 +581,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
+ depth=depth,
)
return result
@@ -562,6 +604,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
) -> Tuple[str, int]:
"""Helper for update_membership.
@@ -599,6 +642,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@@ -732,6 +778,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
+ depth=depth,
content=content,
require_consent=require_consent,
outlier=outlier,
@@ -740,14 +787,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- current_state_ids = await self.state_handler.get_current_state_ids(
- room_id, latest_event_ids=latest_event_ids
+ state_before_join = await self.state_handler.compute_state_after_events(
+ room_id, latest_event_ids
)
# TODO: Refactor into dictionary of explicitly allowed transitions
# between old and new state, with specific error messages for some
# transitions and generic otherwise
- old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+ old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
@@ -798,11 +845,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
- is_host_in_room = await self._is_host_in_room(current_state_ids)
+ is_host_in_room = await self._is_host_in_room(state_before_join)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
- guest_can_join = await self._can_guest_join(current_state_ids)
+ guest_can_join = await self._can_guest_join(state_before_join)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@@ -840,13 +887,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Check if a remote join should be performed.
remote_join, remote_room_hosts = await self._should_perform_remote_join(
- target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+ target.to_string(),
+ room_id,
+ remote_room_hosts,
+ content,
+ is_host_in_room,
+ state_before_join,
)
if remote_join:
if ratelimit:
await self._join_rate_limiter_remote.ratelimit(
requester,
)
+ await self._join_rate_per_room_limiter.ratelimit(
+ requester,
+ key=room_id,
+ update=False,
+ )
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
@@ -967,6 +1024,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
state_event_ids=state_event_ids,
+ depth=depth,
content=content,
require_consent=require_consent,
outlier=outlier,
@@ -979,6 +1037,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
remote_room_hosts: List[str],
content: JsonDict,
is_host_in_room: bool,
+ state_before_join: StateMap[str],
) -> Tuple[bool, List[str]]:
"""
Check whether the server should do a remote join (as opposed to a local
@@ -998,6 +1057,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content: The content to use as the event body of the join. This may
be modified.
is_host_in_room: True if the host is in the room.
+ state_before_join: The state before the join event (i.e. the resolution of
+ the states after its parent events).
Returns:
A tuple of:
@@ -1014,20 +1075,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
- current_state_ids = await self._storage_controllers.state.get_current_state_ids(
- room_id
- )
# If restricted join rules are not being used, a local join can always
# be used.
if not await self.event_auth_handler.has_restricted_join_rules(
- current_state_ids, room_version
+ state_before_join, room_version
):
return False, []
# If the user is invited to the room or already joined, the join
# event can always be issued locally.
- prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -1042,10 +1100,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
#
# If not, generate a new list of remote hosts based on which
# can issue invites.
- event_map = await self.store.get_events(current_state_ids.values())
+ event_map = await self.store.get_events(state_before_join.values())
current_state = {
state_key: event_map[event_id]
- for state_key, event_id in current_state_ids.items()
+ for state_key, event_id in state_before_join.items()
}
allowed_servers = get_servers_from_users(
get_users_which_can_issue_invite(current_state)
@@ -1059,7 +1117,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
- current_state_ids, room_version, user_id, prev_member_event
+ state_before_join, room_version, user_id, prev_member_event
)
# If this is going to be a local join, additional information must
@@ -1069,7 +1127,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
EventContentFields.AUTHORISING_USER
] = await self.event_auth_handler.get_user_which_could_invite(
room_id,
- current_state_ids,
+ state_before_join,
)
return False, []
@@ -1322,7 +1380,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
requester: Requester,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
- ) -> int:
+ prev_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
+ ) -> Tuple[str, int]:
"""Invite a 3PID to a room.
Args:
@@ -1335,9 +1395,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: The transaction ID this is part of, or None if this is not
part of a transaction.
id_access_token: The optional identity server access token.
+        prev_event_ids: The event IDs to use as the prev events.
+        depth: Override the depth used to order the event in the DAG.
+            Should normally be set to None, which will cause the depth to be calculated
+            based on the prev_events.
Returns:
- The new stream ID.
+ Tuple of event ID and stream ordering position
Raises:
ShadowBanError if the requester has been shadow-banned.
@@ -1383,7 +1447,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We don't check the invite against the spamchecker(s) here (through
# user_may_invite) because we'll do it further down the line anyway (in
# update_membership_locked).
- _, stream_id = await self.update_membership(
+ event_id, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
@@ -1402,7 +1466,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
additional_fields=spam_check[1],
)
- stream_id = await self._make_and_store_3pid_invite(
+ event, stream_id = await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@@ -1411,9 +1475,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
inviter,
txn_id=txn_id,
id_access_token=id_access_token,
+ prev_event_ids=prev_event_ids,
+ depth=depth,
)
+ event_id = event.event_id
- return stream_id
+ return event_id, stream_id
async def _make_and_store_3pid_invite(
self,
@@ -1425,7 +1492,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
user: UserID,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
- ) -> int:
+ prev_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
+ ) -> Tuple[EventBase, int]:
room_state = await self._storage_controllers.state.get_current_state(
room_id,
StateFilter.from_types(
@@ -1518,8 +1587,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
},
ratelimit=False,
txn_id=txn_id,
+ prev_event_ids=prev_event_ids,
+ depth=depth,
)
- return stream_id
+ return event, stream_id
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 05cebb5d4d..a744d68c64 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -19,7 +19,6 @@ from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
-from synapse.config.emailconfig import ThreepidBehaviour
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -153,7 +152,7 @@ class _BaseThreepidAuthChecker:
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
- # msisdns are currently always ThreepidBehaviour.REMOTE
+ # msisdns are currently always verified via the IS
if medium == "msisdn":
if not self.hs.config.registration.account_threepid_delegate_msisdn:
raise SynapseError(
@@ -164,18 +163,7 @@ class _BaseThreepidAuthChecker:
threepid_creds,
)
elif medium == "email":
- if (
- self.hs.config.email.threepid_behaviour_email
- == ThreepidBehaviour.REMOTE
- ):
- assert self.hs.config.registration.account_threepid_delegate_email
- threepid = await identity_handler.threepid_from_creds(
- self.hs.config.registration.account_threepid_delegate_email,
- threepid_creds,
- )
- elif (
- self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
- ):
+ if self.hs.config.email.can_verify_email:
threepid = None
row = await self.store.get_threepid_validation_session(
medium,
@@ -227,10 +215,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
_BaseThreepidAuthChecker.__init__(self, hs)
def is_enabled(self) -> bool:
- return self.hs.config.email.threepid_behaviour_email in (
- ThreepidBehaviour.REMOTE,
- ThreepidBehaviour.LOCAL,
- )
+ return self.hs.config.email.can_verify_email
async def check_auth(self, authdict: dict, clientip: str) -> Any:
return await self._check_threepid("email", authdict)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 54b0ec4b97..c42bb8266a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -228,6 +228,7 @@ class Notifier:
# Called when there are new things to stream over replication
self.replication_callbacks: List[Callable[[], None]] = []
+ self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
self._federation_client = hs.get_federation_http_client()
@@ -280,6 +281,19 @@ class Notifier:
"""
self.replication_callbacks.append(cb)
+ def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
+ """Add a callback that will be called when a user joins a room.
+
+ This only fires on genuine membership changes, e.g. "invite" -> "join".
+ Membership transitions like "join" -> "join" (for e.g. displayname changes) do
+ not trigger the callback.
+
+ When called, the callback receives two arguments: the event ID and the room ID.
+ It should *not* return a Deferred - if it needs to do any asynchronous work, a
+ background thread should be started and wrapped with run_as_background_process.
+ """
+ self._new_join_in_room_callbacks.append(cb)
+
async def on_new_room_event(
self,
event: EventBase,
@@ -723,6 +737,10 @@ class Notifier:
for cb in self.replication_callbacks:
cb()
+ def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
+ for cb in self._new_join_in_room_callbacks:
+ cb(event_id, room_id)
+
def notify_remote_server_up(self, server: str) -> None:
"""Notify any replication that a remote server has come back up"""
# We call federation_sender directly rather than registering as a
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index d0cc657b44..1e0ef44fc7 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -328,7 +328,7 @@ class PusherPool:
return None
try:
- p = self.pusher_factory.create_pusher(pusher_config)
+ pusher = self.pusher_factory.create_pusher(pusher_config)
except PusherConfigException as e:
logger.warning(
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
@@ -346,23 +346,28 @@ class PusherPool:
)
return None
- if not p:
+ if not pusher:
return None
- appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
+ appid_pushkey = "%s:%s" % (pusher.app_id, pusher.pushkey)
- byuser = self.pushers.setdefault(pusher_config.user_name, {})
+ byuser = self.pushers.setdefault(pusher.user_id, {})
if appid_pushkey in byuser:
- byuser[appid_pushkey].on_stop()
- byuser[appid_pushkey] = p
+ previous_pusher = byuser[appid_pushkey]
+ previous_pusher.on_stop()
- synapse_pushers.labels(type(p).__name__, p.app_id).inc()
+ synapse_pushers.labels(
+ type(previous_pusher).__name__, previous_pusher.app_id
+ ).dec()
+ byuser[appid_pushkey] = pusher
+
+ synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
- user_id = pusher_config.user_name
- last_stream_ordering = pusher_config.last_stream_ordering
+ user_id = pusher.user_id
+ last_stream_ordering = pusher.last_stream_ordering
if last_stream_ordering:
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering
@@ -372,9 +377,9 @@ class PusherPool:
# risk missing push.
have_notifs = True
- p.on_started(have_notifs)
+ pusher.on_started(have_notifs)
- return p
+ return pusher
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2f59245058..e4f2201c92 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.python.failure import Failure
-from synapse.api.constants import EventTypes, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.federation import send_queue
from synapse.federation.sender import FederationSender
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
@@ -219,6 +219,21 @@ class ReplicationDataHandler:
membership=row.data.membership,
)
+ # If this event is a join, make a note of it so we have an accurate
+ # cross-worker room rate limit.
+ # TODO: Erik said we should exclude rows that came from ex_outliers
+ # here, but I don't see how we can determine that. I guess we could
+ # add a flag to row.data?
+ if (
+ row.data.type == EventTypes.Member
+ and row.data.membership == Membership.JOIN
+ and not row.data.outlier
+ ):
+ # TODO retrieve the previous state, and exclude join -> join transitions
+ self.notifier.notify_user_joined_room(
+ row.data.event_id, row.data.room_id
+ )
+
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 26f4fa7cfd..14b6705862 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -98,6 +98,7 @@ class EventsStreamEventRow(BaseEventsStreamRow):
relates_to: Optional[str]
membership: Optional[str]
rejected: bool
+ outlier: bool
@attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f0614a2897..ba2f7fa6d8 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -373,6 +373,7 @@ class UserRestServletV2(RestServlet):
if (
self.hs.config.email.email_enable_notifs
and self.hs.config.email.email_notif_for_new_users
+ and medium == "email"
):
await self.pusher_pool.add_pusher(
user_id=user_id,
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index bdc4a9c068..0cc87a4001 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -28,7 +28,6 @@ from synapse.api.errors import (
SynapseError,
ThreepidValidationError,
)
-from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
@@ -64,7 +63,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.config = hs.config
self.identity_handler = hs.get_identity_handler()
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.can_verify_email:
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email.email_app_name,
@@ -73,11 +72,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.email.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "User password resets have been disabled due to lack of email config"
- )
+ if not self.config.email.can_verify_email:
+ logger.warning(
+ "User password resets have been disabled due to lack of email config"
+ )
raise SynapseError(
400, "Email-based password resets have been disabled on this server"
)
@@ -129,35 +127,21 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.registration.account_threepid_delegate_email
-
- # Have the configured identity server handle the request
- ret = await self.identity_handler.requestEmailToken(
- self.hs.config.registration.account_threepid_delegate_email,
- email,
- client_secret,
- send_attempt,
- next_link,
- )
- else:
- # Send password reset emails from Synapse
- sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
- self.mailer.send_password_reset_mail,
- next_link,
- )
-
- # Wrap the session id in a JSON object
- ret = {"sid": sid}
+ # Send password reset emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_password_reset_mail,
+ next_link,
+ )
threepid_send_requests.labels(type="email", reason="password_reset").observe(
send_attempt
)
- return 200, ret
+ # Wrap the session id in a JSON object
+ return 200, {"sid": sid}
class PasswordRestServlet(RestServlet):
@@ -349,7 +333,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.store = self.hs.get_datastores().main
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.can_verify_email:
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email.email_app_name,
@@ -358,11 +342,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.email.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "Adding emails have been disabled due to lack of an email config"
- )
+ if not self.config.email.can_verify_email:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
+ )
raise SynapseError(
400, "Adding an email to your account is disabled on this server"
)
@@ -413,35 +396,20 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.registration.account_threepid_delegate_email
-
- # Have the configured identity server handle the request
- ret = await self.identity_handler.requestEmailToken(
- self.hs.config.registration.account_threepid_delegate_email,
- email,
- client_secret,
- send_attempt,
- next_link,
- )
- else:
- # Send threepid validation emails from Synapse
- sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
- self.mailer.send_add_threepid_mail,
- next_link,
- )
-
- # Wrap the session id in a JSON object
- ret = {"sid": sid}
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_add_threepid_mail,
+ next_link,
+ )
threepid_send_requests.labels(type="email", reason="add_threepid").observe(
send_attempt
)
- return 200, ret
+ # Wrap the session id in a JSON object
+ return 200, {"sid": sid}
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@@ -534,25 +502,18 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.can_verify_email:
self._failure_email_template = (
self.config.email.email_add_threepid_template_failure_html
)
async def on_GET(self, request: Request) -> None:
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.email.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "Adding emails have been disabled due to lack of an email config"
- )
- raise SynapseError(
- 400, "Adding an email to your account is disabled on this server"
+ if not self.config.email.can_verify_email:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
)
- elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
raise SynapseError(
- 400,
- "This homeserver is not validating threepids. Use an identity server "
- "instead.",
+ 400, "Adding an email to your account is disabled on this server"
)
sid = parse_string(request, "sid", required=True)
@@ -743,10 +704,12 @@ class ThreepidBindRestServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
+ assert_params_in_dict(
+ body, ["id_server", "sid", "id_access_token", "client_secret"]
+ )
id_server = body["id_server"]
sid = body["sid"]
- id_access_token = body.get("id_access_token") # optional
+ id_access_token = body["id_access_token"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index dd75e40f34..0437c87d8d 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -28,7 +28,7 @@ from typing import (
from typing_extensions import TypedDict
-from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService
@@ -172,7 +172,13 @@ class LoginRestServlet(RestServlet):
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
- appservice = self.auth.get_appservice_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
+ appservice = requester.app_service
+
+ if appservice is None:
+ raise InvalidClientTokenError(
+ "This login method is only valid for application services"
+ )
if appservice.is_rate_limited():
await self._address_ratelimiter.ratelimit(
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 3644705e6a..8896f2df50 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -40,6 +40,10 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
+ self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ}
+ if hs.config.experimental.msc2285_enabled:
+ self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE)
+
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
@@ -49,13 +53,7 @@ class ReadMarkerRestServlet(RestServlet):
body = parse_json_object_from_request(request)
- valid_receipt_types = {
- ReceiptTypes.READ,
- ReceiptTypes.FULLY_READ,
- ReceiptTypes.READ_PRIVATE,
- }
-
- unrecognized_types = set(body.keys()) - valid_receipt_types
+ unrecognized_types = set(body.keys()) - self._known_receipt_types
if unrecognized_types:
# It's fine if there are unrecognized receipt types, but let's log
# it to help debug clients that have typoed the receipt type.
@@ -65,31 +63,25 @@ class ReadMarkerRestServlet(RestServlet):
# types.
logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types)
- read_event_id = body.get(ReceiptTypes.READ, None)
- if read_event_id:
- await self.receipts_handler.received_client_receipt(
- room_id,
- ReceiptTypes.READ,
- user_id=requester.user.to_string(),
- event_id=read_event_id,
- )
-
- read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
- if read_private_event_id and self.config.experimental.msc2285_enabled:
- await self.receipts_handler.received_client_receipt(
- room_id,
- ReceiptTypes.READ_PRIVATE,
- user_id=requester.user.to_string(),
- event_id=read_private_event_id,
- )
-
- read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)
- if read_marker_event_id:
- await self.read_marker_handler.received_client_read_marker(
- room_id,
- user_id=requester.user.to_string(),
- event_id=read_marker_event_id,
- )
+ for receipt_type in self._known_receipt_types:
+ event_id = body.get(receipt_type, None)
+ # TODO Add validation to reject non-string event IDs.
+ if not event_id:
+ continue
+
+ if receipt_type == ReceiptTypes.FULLY_READ:
+ await self.read_marker_handler.received_client_read_marker(
+ room_id,
+ user_id=requester.user.to_string(),
+ event_id=event_id,
+ )
+ else:
+ await self.receipts_handler.received_client_receipt(
+ room_id,
+ receipt_type,
+ user_id=requester.user.to_string(),
+ event_id=event_id,
+ )
return 200, {}
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 4b03eb876b..409bfd43c1 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -39,31 +39,27 @@ class ReceiptRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
- self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
+ self._known_receipt_types = {ReceiptTypes.READ}
+ if hs.config.experimental.msc2285_enabled:
+ self._known_receipt_types.update(
+ (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
+ )
+
async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- if self.hs.config.experimental.msc2285_enabled and receipt_type not in [
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.FULLY_READ,
- ]:
+ if receipt_type not in self._known_receipt_types:
raise SynapseError(
400,
- "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'",
+ f"Receipt type must be {', '.join(self._known_receipt_types)}",
)
- elif (
- not self.hs.config.experimental.msc2285_enabled
- and receipt_type != ReceiptTypes.READ
- ):
- raise SynapseError(400, "Receipt type must be 'm.read'")
parse_json_object_from_request(request, allow_empty_body=False)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index e8e51a9c66..a8402cdb3a 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -31,7 +31,6 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
-from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.server import is_threepid_reserved
@@ -74,7 +73,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.config = hs.config
- if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.hs.config.email.can_verify_email:
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email.email_app_name,
@@ -83,13 +82,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if (
- self.hs.config.email.local_threepid_handling_disabled_due_to_email_config
- ):
- logger.warning(
- "Email registration has been disabled due to lack of email config"
- )
+ if not self.hs.config.email.can_verify_email:
+ logger.warning(
+ "Email registration has been disabled due to lack of email config"
+ )
raise SynapseError(
400, "Email-based registration has been disabled on this server"
)
@@ -138,35 +134,21 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.registration.account_threepid_delegate_email
-
- # Have the configured identity server handle the request
- ret = await self.identity_handler.requestEmailToken(
- self.hs.config.registration.account_threepid_delegate_email,
- email,
- client_secret,
- send_attempt,
- next_link,
- )
- else:
- # Send registration emails from Synapse
- sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
- self.mailer.send_registration_mail,
- next_link,
- )
-
- # Wrap the session id in a JSON object
- ret = {"sid": sid}
+ # Send registration emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_registration_mail,
+ next_link,
+ )
threepid_send_requests.labels(type="email", reason="register").observe(
send_attempt
)
- return 200, ret
+ # Wrap the session id in a JSON object
+ return 200, {"sid": sid}
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
@@ -260,7 +242,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.can_verify_email:
self._failure_email_template = (
self.config.email.email_registration_template_failure_html
)
@@ -270,11 +252,10 @@ class RegistrationSubmitTokenServlet(RestServlet):
raise SynapseError(
400, "This medium is currently not supported for registration"
)
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.email.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "User registration via email has been disabled due to lack of email config"
- )
+ if not self.config.email.can_verify_email:
+ logger.warning(
+ "User registration via email has been disabled due to lack of email config"
+ )
raise SynapseError(
400, "Email-based registration is disabled on this server"
)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 54a849eac9..b36c98a08e 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -109,10 +109,64 @@ class MediaInfo:
class PreviewUrlResource(DirectServeJsonResource):
"""
- Generating URL previews is a complicated task which many potential pitfalls.
-
- See docs/development/url_previews.md for discussion of the design and
- algorithm followed in this module.
+ The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API
+ for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
+ specific additions).
+
+ This does have trade-offs compared to other designs:
+
+ * Pros:
+ * Simple and flexible; can be used by any clients at any point
+ * Cons:
+ * If each homeserver provides one of these independently, all the homeservers in a
+ room may needlessly DoS the target URI
+ * The URL metadata must be stored somewhere, rather than just using Matrix
+ itself to store the media.
+ * Matrix cannot be used to distribute the metadata between homeservers.
+
+ When Synapse is asked to preview a URL it does the following:
+
+ 1. Checks against a URL blacklist (defined as `url_preview_url_blacklist` in the
+ config).
+ 2. Checks the URL against an in-memory cache and returns the result if it exists. (This
+ is also used to de-duplicate processing of multiple in-flight requests at once.)
+ 3. Kicks off a background process to generate a preview:
+ 1. Checks URL and timestamp against the database cache and returns the result if it
+ has not expired and was successful (a 2xx return code).
+ 2. Checks if the URL matches an oEmbed (https://oembed.com/) pattern. If it
+ does, update the URL to download.
+ 3. Downloads the URL and stores it into a file via the media storage provider
+ and saves the local media metadata.
+ 4. If the media is an image:
+ 1. Generates thumbnails.
+ 2. Generates an Open Graph response based on image properties.
+ 5. If the media is HTML:
+ 1. Decodes the HTML via the stored file.
+ 2. Generates an Open Graph response from the HTML.
+ 3. If a JSON oEmbed URL was found in the HTML via autodiscovery:
+ 1. Downloads the URL and stores it into a file via the media storage provider
+ and saves the local media metadata.
+ 2. Convert the oEmbed response to an Open Graph response.
+ 3. Override any Open Graph data from the HTML with data from oEmbed.
+ 4. If an image exists in the Open Graph response:
+ 1. Downloads the URL and stores it into a file via the media storage
+ provider and saves the local media metadata.
+ 2. Generates thumbnails.
+ 3. Updates the Open Graph response based on image properties.
+ 6. If the media is JSON and an oEmbed URL was found:
+ 1. Convert the oEmbed response to an Open Graph response.
+ 2. If a thumbnail or image is in the oEmbed response:
+ 1. Downloads the URL and stores it into a file via the media storage
+ provider and saves the local media metadata.
+ 2. Generates thumbnails.
+ 3. Updates the Open Graph response based on image properties.
+ 7. Stores the result in the database cache.
+ 4. Returns the result.
+
+ The in-memory cache expires after 1 hour.
+
+ Expired entries in the database cache (and their associated media files) are
+ deleted every 10 seconds. The default expiration time is 1 hour from download.
"""
isLeaf = True
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 2295adfaa7..5f725c7600 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -17,9 +17,11 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
from synapse.http.server import (
DirectServeJsonResource,
+ respond_with_json,
set_corp_headers,
set_cors_headers,
)
@@ -309,6 +311,19 @@ class ThumbnailResource(DirectServeJsonResource):
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
+ logger.debug(
+ "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ thumbnail_infos,
+ )
+
+ # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
+ # different code path to handle it.
+ assert not self.dynamic_thumbnails
+
if thumbnail_infos:
file_info = self._select_thumbnail(
desired_width,
@@ -384,8 +399,29 @@ class ThumbnailResource(DirectServeJsonResource):
file_info.thumbnail.length,
)
else:
+ # This might be because:
+ # 1. We can't create thumbnails for the given media (corrupted or
+ # unsupported file type), or
+ # 2. The thumbnailing process never ran or errored out initially
+ # when the media was first uploaded (these bugs should be
+ # reported and fixed).
+ # Note that we don't attempt to generate a thumbnail now because
+ # `dynamic_thumbnails` is disabled.
logger.info("Failed to find any generated thumbnails")
- respond_404(request)
+
+ respond_with_json(
+ request,
+ 400,
+ cs_error(
+ "Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
+ % (
+ request.postpath,
+ ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
+ ),
+ code=Codes.UNKNOWN,
+ ),
+ send_cors=True,
+ )
def _select_thumbnail(
self,
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index 6ac9dbc7c9..b9402cfb75 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import ThreepidValidationError
-from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.server import DirectServeHtmlResource
from synapse.http.servlet import parse_string
from synapse.util.stringutils import assert_valid_client_secret
@@ -46,9 +45,6 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- self._local_threepid_handling_disabled_due_to_email_config = (
- hs.config.email.local_threepid_handling_disabled_due_to_email_config
- )
self._confirmation_email_template = (
hs.config.email.email_password_reset_template_confirmation_html
)
@@ -59,8 +55,8 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
hs.config.email.email_password_reset_template_failure_html
)
- # This resource should not be mounted if threepid behaviour is not LOCAL
- assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+ # This resource should only be mounted if email validation is enabled
+ assert hs.config.email.can_verify_email
async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
sid = parse_string(request, "sid", required=True)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 781d9f06da..e3faa52cd6 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -24,14 +24,12 @@ from typing import (
DefaultDict,
Dict,
FrozenSet,
- Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
- Union,
)
import attr
@@ -47,6 +45,7 @@ from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServ
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
+from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -54,6 +53,7 @@ from synapse.util.metrics import Measure, measure_func
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -83,17 +83,23 @@ def _gen_state_id() -> str:
class _StateCacheEntry:
- __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
+ __slots__ = ["_state", "state_group", "prev_group", "delta_ids"]
def __init__(
self,
- state: StateMap[str],
+ state: Optional[StateMap[str]],
state_group: Optional[int],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
+ if state is None and state_group is None:
+ raise Exception("Either state or state group must be not None")
+
# A map from (type, state_key) to event_id.
- self.state = frozendict(state)
+ #
+ # This can be None if we have a `state_group` (as then we can fetch the
+ # state from the DB.)
+ self._state = frozendict(state) if state is not None else None
# the ID of a state group if one and only one is involved.
# otherwise, None otherwise?
@@ -102,20 +108,30 @@ class _StateCacheEntry:
self.prev_group = prev_group
self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
- # The `state_id` is a unique ID we generate that can be used as ID for
- # this collection of state. Usually this would be the same as the
- # state group, but on worker instances we can't generate a new state
- # group each time we resolve state, so we generate a separate one that
- # isn't persisted and is used solely for caches.
- # `state_id` is either a state_group (and so an int) or a string. This
- # ensures we don't accidentally persist a state_id as a stateg_group
- if state_group:
- self.state_id: Union[str, int] = state_group
- else:
- self.state_id = _gen_state_id()
+ async def get_state(
+ self,
+ state_storage: "StateStorageController",
+ state_filter: Optional["StateFilter"] = None,
+ ) -> StateMap[str]:
+ """Get the state map for this entry, either from the in-memory state or
+ looking up the state group in the DB.
+ """
+
+ if self._state is not None:
+ return self._state
+
+ assert self.state_group is not None
+
+ return await state_storage.get_state_ids_for_group(
+ self.state_group, state_filter
+ )
def __len__(self) -> int:
- return len(self.state)
+ # The len is used to estimate how large this cache entry is, for
+ # cache eviction purposes. This is why if `self._state` is None it's fine
+ # to return 1.
+
+ return len(self._state) if self._state else 1
class StateHandler:
@@ -137,23 +153,29 @@ class StateHandler:
ReplicationUpdateCurrentStateRestServlet.make_client(hs)
)
- async def get_current_state_ids(
+ async def compute_state_after_events(
self,
room_id: str,
- latest_event_ids: Collection[str],
+ event_ids: Collection[str],
+ state_filter: Optional[StateFilter] = None,
) -> StateMap[str]:
- """Get the current state, or the state at a set of events, for a room
+ """Fetch the state after each of the given event IDs. Resolve them and return.
+
+ This is typically used where `event_ids` is a collection of forward extremities
+ in a room, intended to become the `prev_events` of a new event E. If so, the
+ return value of this function represents the state before E.
Args:
- room_id:
- latest_event_ids: The forward extremities to resolve.
+ room_id: the room_id containing the given events.
+ event_ids: the events whose state should be fetched and resolved.
Returns:
- the state dict, mapping from (event_type, state_key) -> event_id
+ the state dict (a mapping from (event_type, state_key) -> event_id) which
+ holds the resolution of the states after the given event IDs.
"""
- logger.debug("calling resolve_state_groups from get_current_state_ids")
- ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
- return ret.state
+ logger.debug("calling resolve_state_groups from compute_state_after_events")
+ ret = await self.resolve_state_groups_for_events(room_id, event_ids)
+ return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: List[str]
@@ -177,7 +199,8 @@ class StateHandler:
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
- return await self.store.get_joined_users_from_state(room_id, entry)
+ state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+ return await self.store.get_joined_users_from_state(room_id, state, entry)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
@@ -192,7 +215,8 @@ class StateHandler:
The hosts in the room at the given events
"""
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
- return await self.store.get_joined_hosts(room_id, entry)
+ state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+ return await self.store.get_joined_hosts(room_id, state, entry)
async def compute_event_context(
self,
@@ -227,10 +251,19 @@ class StateHandler:
#
if state_ids_before_event:
# if we're given the state before the event, then we use that
- state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
- entry = None
+
+ # .. though we need to get a state group for it.
+ state_group_before_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=None,
+ delta_ids=None,
+ current_state_ids=state_ids_before_event,
+ )
+ )
else:
# otherwise, we'll need to resolve the state across the prev_events.
@@ -264,36 +297,32 @@ class StateHandler:
await_full_state=False,
)
- state_ids_before_event = entry.state
- state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
+ state_ids_before_event = None
+
+ # We make sure that we have a state group assigned to the state.
+ if entry.state_group is None:
+ # store_state_group requires us to have either a previous state group
+ # (with deltas) or the complete state map. So, if we don't have a
+ # previous state group, load the complete state map now.
+ if state_group_before_event_prev_group is None:
+ state_ids_before_event = await entry.get_state(
+ self._state_storage_controller, StateFilter.all()
+ )
- #
- # make sure that we have a state group at that point. If it's not a state event,
- # that will be the state group for the new event. If it *is* a state event,
- # it might get rejected (in which case we'll need to persist it with the
- # previous state group)
- #
-
- if not state_group_before_event:
- state_group_before_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
- current_state_ids=state_ids_before_event,
+ state_group_before_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ current_state_ids=state_ids_before_event,
+ )
)
- )
-
- # Assign the new state group to the cached state entry.
- #
- # Note that this can race in that we could generate multiple state
- # groups for the same state entry, but that is just inefficient
- # rather than dangerous.
- if entry and entry.state_group is None:
entry.state_group = state_group_before_event
+ else:
+ state_group_before_event = entry.state_group
#
# now if it's not a state event, we're done
@@ -315,13 +344,18 @@ class StateHandler:
#
key = (event.type, event.state_key)
- if key in state_ids_before_event:
- replaces = state_ids_before_event[key]
- if replaces != event.event_id:
- event.unsigned["replaces_state"] = replaces
- state_ids_after_event = dict(state_ids_before_event)
- state_ids_after_event[key] = event.event_id
+ if state_ids_before_event is not None:
+ replaces = state_ids_before_event.get(key)
+ else:
+ replaces_state_map = await entry.get_state(
+ self._state_storage_controller, StateFilter.from_types([key])
+ )
+ replaces = replaces_state_map.get(key)
+
+ if replaces and replaces != event.event_id:
+ event.unsigned["replaces_state"] = replaces
+
delta_ids = {key: event.event_id}
state_group_after_event = (
@@ -330,7 +364,7 @@ class StateHandler:
event.room_id,
prev_group=state_group_before_event,
delta_ids=delta_ids,
- current_state_ids=state_ids_after_event,
+ current_state_ids=None,
)
)
@@ -372,9 +406,6 @@ class StateHandler:
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
- state = await self._state_storage_controller.get_state_for_groups(
- state_group_ids_set
- )
(
prev_group,
delta_ids,
@@ -382,7 +413,7 @@ class StateHandler:
state_group_id
)
return _StateCacheEntry(
- state=state[state_group_id],
+ state=None,
state_group=state_group_id,
prev_group=prev_group,
delta_ids=delta_ids,
@@ -405,31 +436,6 @@ class StateHandler:
)
return result
- async def resolve_events(
- self,
- room_version: str,
- state_sets: Collection[Iterable[EventBase]],
- event: EventBase,
- ) -> StateMap[EventBase]:
- logger.info(
- "Resolving state for %s with %d groups", event.room_id, len(state_sets)
- )
- state_set_ids = [
- {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
- ]
-
- state_map = {ev.event_id: ev for st in state_sets for ev in st}
-
- new_state = await self._state_resolution_handler.resolve_events_with_store(
- event.room_id,
- room_version,
- state_set_ids,
- event_map=state_map,
- state_res_store=StateResolutionStore(self.store),
- )
-
- return {key: state_map[ev_id] for key, ev_id in new_state.items()}
-
async def update_current_state(self, room_id: str) -> None:
"""Recalculates the current state for a room, and persists it.
@@ -752,6 +758,12 @@ def _make_state_cache_entry(
delta_ids: Optional[StateMap[str]] = None
for old_group, old_state in state_groups_ids.items():
+ if old_state.keys() - new_state.keys():
+ # Currently we don't support deltas that remove keys from the state
+ # map, so we have to ignore this group as a candidate to base the
+ # new group on.
+ continue
+
n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b8c8dcd76b..a2f8310388 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -96,6 +96,10 @@ class SQLBaseStore(metaclass=ABCMeta):
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
+ Note that this function does not invalidate any remote caches, only the
+ local in-memory ones. Any remote invalidation must be performed before
+ calling this.
+
Args:
cache_name
key: Entry to invalidate. If None then invalidates the entire
@@ -112,7 +116,10 @@ class SQLBaseStore(metaclass=ABCMeta):
if key is None:
cache.invalidate_all()
else:
- cache.invalidate(tuple(key))
+ # Prefer any local-only invalidation method. Invalidating any non-local
+ # cache must be done before this.
+ invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
+ invalidate_method(tuple(key))
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
index 55649719f6..45101cda7a 100644
--- a/synapse/storage/controllers/__init__.py
+++ b/synapse/storage/controllers/__init__.py
@@ -43,4 +43,6 @@ class StorageControllers:
self.persistence = None
if stores.persist_events:
- self.persistence = EventsPersistenceStorageController(hs, stores)
+ self.persistence = EventsPersistenceStorageController(
+ hs, stores, self.state
+ )
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index ea499ce0f8..cf98b0ab48 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -48,9 +48,11 @@ from synapse.events.snapshot import EventContext
from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
@@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
current state and forward extremity changes.
"""
- def __init__(self, hs: "HomeServer", stores: Databases):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ stores: Databases,
+ state_controller: StateStorageController,
+ ):
# We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same
# store for now.
@@ -325,6 +332,7 @@ class EventsPersistenceStorageController:
self._process_event_persist_queue_task
)
self._state_resolution_handler = hs.get_state_resolution_handler()
+ self._state_controller = state_controller
async def _process_event_persist_queue_task(
self,
@@ -504,7 +512,7 @@ class EventsPersistenceStorageController:
state_res_store=StateResolutionStore(self.main_store),
)
- return res.state
+ return await res.get_state(self._state_controller, StateFilter.all())
async def _persist_event_batch(
self, _room_id: str, task: _PersistEventsTask
@@ -940,7 +948,8 @@ class EventsPersistenceStorageController:
events_context,
)
- return res.state, None, new_latest_event_ids
+ full_state = await res.get_state(self._state_controller)
+ return full_state, None, new_latest_event_ids
async def _prune_extremities(
self,
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index d3a44bc876..e08f956e6e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -346,7 +346,7 @@ class StateStorageController:
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
- current_state_ids: StateMap[str],
+ current_state_ids: Optional[StateMap[str]],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e21ab08515..ea672ff89e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Collection,
Dict,
@@ -168,6 +169,7 @@ class LoggingDatabaseConnection:
*,
txn_name: Optional[str] = None,
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+ async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
) -> "LoggingTransaction":
if not txn_name:
@@ -178,6 +180,7 @@ class LoggingDatabaseConnection:
name=txn_name,
database_engine=self.engine,
after_callbacks=after_callbacks,
+ async_after_callbacks=async_after_callbacks,
exception_callbacks=exception_callbacks,
)
@@ -209,6 +212,9 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
+_AsyncCallbackListEntry = Tuple[
+ Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
+]
P = ParamSpec("P")
R = TypeVar("R")
@@ -227,6 +233,10 @@ class LoggingTransaction:
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
+ async_after_callbacks: A list that asynchronous callbacks will be appended
+ to by `async_call_after` which should run, before after_callbacks, on
+ successful completion of the transaction. None indicates that no
+ callbacks should be allowed to be scheduled to run.
exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +248,7 @@ class LoggingTransaction:
"name",
"database_engine",
"after_callbacks",
+ "async_after_callbacks",
"exception_callbacks",
]
@@ -247,12 +258,14 @@ class LoggingTransaction:
name: str,
database_engine: BaseDatabaseEngine,
after_callbacks: Optional[List[_CallbackListEntry]] = None,
+ async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
self.txn = txn
self.name = name
self.database_engine = database_engine
self.after_callbacks = after_callbacks
+ self.async_after_callbacks = async_after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(
@@ -277,6 +290,28 @@ class LoggingTransaction:
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
+ def async_call_after(
+ self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+ ) -> None:
+ """Call the given asynchronous callback on the main twisted thread after
+ the transaction has finished (but before those added in `call_after`).
+
+ Mostly used to invalidate remote caches after transactions.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `async_call_after`
+ will accumulate across transaction attempts and will _all_ be called once a
+ transaction attempt succeeds, regardless of whether previous transaction
+ attempts failed. Otherwise, if all transaction attempts fail, all
+ `call_on_exception` callbacks will be run instead.
+ """
+ # if self.async_after_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.async_after_callbacks is not None
+ # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+ self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
+
def call_on_exception(
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
) -> None:
@@ -574,6 +609,7 @@ class DatabasePool:
conn: LoggingDatabaseConnection,
desc: str,
after_callbacks: List[_CallbackListEntry],
+ async_after_callbacks: List[_AsyncCallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
func: Callable[Concatenate[LoggingTransaction, P], R],
*args: P.args,
@@ -597,6 +633,7 @@ class DatabasePool:
conn
desc
after_callbacks
+ async_after_callbacks
exception_callbacks
func
*args
@@ -659,6 +696,7 @@ class DatabasePool:
cursor = conn.cursor(
txn_name=name,
after_callbacks=after_callbacks,
+ async_after_callbacks=async_after_callbacks,
exception_callbacks=exception_callbacks,
)
try:
@@ -798,6 +836,7 @@ class DatabasePool:
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = []
+ async_after_callbacks: List[_AsyncCallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
@@ -809,6 +848,7 @@ class DatabasePool:
self.new_transaction,
desc,
after_callbacks,
+ async_after_callbacks,
exception_callbacks,
func,
*args,
@@ -817,13 +857,17 @@ class DatabasePool:
**kwargs,
)
+ # We order these assuming that async functions call out to external
+ # systems (e.g. to invalidate a cache) and the sync functions make these
+ # changes on any local in-memory caches/similar, and thus must be second.
+ for async_callback, async_args, async_kwargs in async_after_callbacks:
+ await async_callback(*async_args, **async_kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
-
return cast(R, result)
except Exception:
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
+ for exception_callback, after_args, after_kwargs in exception_callbacks:
+ exception_callback(*after_args, **after_kwargs)
raise
# To handle cancellation, we ensure that `after_callback`s and
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e284454b66..64b70a7b28 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -371,52 +371,30 @@ class ApplicationServiceTransactionWorkerStore(
device_list_summary=DeviceListUpdates(),
)
- async def set_appservice_last_pos(self, pos: int) -> None:
- def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
- txn.execute(
- "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
- )
+ async def get_appservice_last_pos(self) -> int:
+ """
+ Get the last stream ordering position for the appservice process.
+ """
- await self.db_pool.runInteraction(
- "set_appservice_last_pos", set_appservice_last_pos_txn
+ return await self.db_pool.simple_select_one_onecol(
+ table="appservice_stream_position",
+ retcol="stream_ordering",
+ keyvalues={},
+ desc="get_appservice_last_pos",
)
- async def get_new_events_for_appservice(
- self, current_id: int, limit: int
- ) -> Tuple[int, List[EventBase]]:
- """Get all new events for an appservice"""
-
- def get_new_events_for_appservice_txn(
- txn: LoggingTransaction,
- ) -> Tuple[int, List[str]]:
- sql = (
- "SELECT e.stream_ordering, e.event_id"
- " FROM events AS e"
- " WHERE"
- " (SELECT stream_ordering FROM appservice_stream_position)"
- " < e.stream_ordering"
- " AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- " LIMIT ?"
- )
-
- txn.execute(sql, (current_id, limit))
- rows = txn.fetchall()
-
- upper_bound = current_id
- if len(rows) == limit:
- upper_bound = rows[-1][0]
-
- return upper_bound, [row[1] for row in rows]
+ async def set_appservice_last_pos(self, pos: int) -> None:
+ """
+ Set the last stream ordering position for the appservice process.
+ """
- upper_bound, event_ids = await self.db_pool.runInteraction(
- "get_new_events_for_appservice", get_new_events_for_appservice_txn
+ await self.db_pool.simple_update_one(
+ table="appservice_stream_position",
+ keyvalues={},
+ updatevalues={"stream_ordering": pos},
+ desc="set_appservice_last_pos",
)
- events = await self.get_events_as_list(event_ids, get_prev_content=True)
-
- return upper_bound, events
-
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 1653a6a9b6..2367ddeea3 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -193,7 +193,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
relates_to: Optional[str],
backfilled: bool,
) -> None:
- self._invalidate_get_event_cache(event_id)
+ # This invalidates any local in-memory cached event objects, the original
+ # process triggering the invalidation is responsible for clearing any external
+ # cached objects.
+ self._invalidate_local_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
self.get_latest_event_ids_in_room.invalidate((room_id,))
@@ -208,7 +211,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
if redacts:
- self._invalidate_get_event_cache(redacts)
+ self._invalidate_local_get_event_cache(redacts)
# Caches which might leak edits must be invalidated for the event being
# redacted.
self.get_relations_for_event.invalidate((redacts,))
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index fd3fc298b3..58177ecec1 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# changed its content in the database. We can't call
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
# right type.
- txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+ self.invalidate_get_event_cache_after_txn(txn, event.event_id)
# Send that invalidation to replication so that other workers also invalidate
# the event cache.
self._send_invalidation_to_replication(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index eb4efbb93c..156e1bd5ab 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1293,7 +1293,7 @@ class PersistEventsStore:
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
- txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
+ self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
# Then update the `stream_ordering` position to mark the latest
# event as the front of the room. This should not be done for
# backfilled events because backfilled events have negative
@@ -1669,13 +1669,13 @@ class PersistEventsStore:
if not row["rejects"] and not row["redacts"]:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
- def prefill() -> None:
+ async def prefill() -> None:
for cache_entry in to_prefill:
- self.store._get_event_cache.set(
+ await self.store._get_event_cache.set(
(cache_entry.event.event_id,), cache_entry
)
- txn.call_after(prefill)
+ txn.async_call_after(prefill)
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
"""Invalidate the caches for the redacted event.
@@ -1684,7 +1684,7 @@ class PersistEventsStore:
_invalidate_caches_for_event.
"""
assert event.redacts is not None
- txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+ self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index eeca85fc94..6e8aeed7b4 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -67,6 +67,8 @@ class _BackgroundUpdates:
EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows"
EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index"
+ EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
@@ -253,6 +255,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
replaces_index="ev_edges_id",
)
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+ self._background_events_populate_state_key_rejections,
+ )
+
async def _background_reindex_fields_sender(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -1399,3 +1406,83 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
return batch_size
+
+ async def _background_events_populate_state_key_rejections(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Back-populate `events.state_key` and `events.rejection_reason`."""
+
+ min_stream_ordering_exclusive = progress["min_stream_ordering_exclusive"]
+ max_stream_ordering_inclusive = progress["max_stream_ordering_inclusive"]
+
+ def _populate_txn(txn: LoggingTransaction) -> bool:
+ """Returns True if we're done."""
+
+ # first we need to find an endpoint.
+ # we need to find the final row in the batch of batch_size, which means
+ # we need to skip over (batch_size-1) rows and get the next row.
+ txn.execute(
+ """
+ SELECT stream_ordering FROM events
+ WHERE stream_ordering > ? AND stream_ordering <= ?
+ ORDER BY stream_ordering
+ LIMIT 1 OFFSET ?
+ """,
+ (
+ min_stream_ordering_exclusive,
+ max_stream_ordering_inclusive,
+ batch_size - 1,
+ ),
+ )
+
+ endpoint = None
+ row = txn.fetchone()
+ if row:
+ endpoint = row[0]
+
+ where_clause = "stream_ordering > ?"
+ args = [min_stream_ordering_exclusive]
+ if endpoint:
+ where_clause += " AND stream_ordering <= ?"
+ args.append(endpoint)
+
+ # now do the updates.
+ txn.execute(
+ f"""
+ UPDATE events
+ SET state_key = (SELECT state_key FROM state_events se WHERE se.event_id = events.event_id),
+ rejection_reason = (SELECT reason FROM rejections rej WHERE rej.event_id = events.event_id)
+ WHERE ({where_clause})
+ """,
+ args,
+ )
+
+ logger.info(
+ "populated new `events` columns up to %s/%i: updated %i rows",
+ endpoint,
+ max_stream_ordering_inclusive,
+ txn.rowcount,
+ )
+
+ if endpoint is None:
+ # we're done
+ return True
+
+ progress["min_stream_ordering_exclusive"] = endpoint
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+ progress,
+ )
+ return False
+
+ done = await self.db_pool.runInteraction(
+ desc="events_populate_state_key_rejections", func=_populate_txn
+ )
+
+ if done:
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS
+ )
+
+ return batch_size
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b99b107784..5914a35420 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -79,7 +79,7 @@ from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.lrucache import AsyncLruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -238,7 +238,9 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
+ self._get_event_cache: AsyncLruCache[
+ Tuple[str], EventCacheEntry
+ ] = AsyncLruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -292,25 +294,6 @@ class EventsWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
- async def get_received_ts(self, event_id: str) -> Optional[int]:
- """Get received_ts (when it was persisted) for the event.
-
- Raises an exception for unknown events.
-
- Args:
- event_id: The event ID to query.
-
- Returns:
- Timestamp in milliseconds, or None for events that were persisted
- before received_ts was implemented.
- """
- return await self.db_pool.simple_select_one_onecol(
- table="events",
- keyvalues={"event_id": event_id},
- retcol="received_ts",
- desc="get_received_ts",
- )
-
async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased
from the database due to a redaction.
@@ -617,7 +600,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
map from event id to result
"""
- event_entry_map = self._get_events_from_cache(
+ event_entry_map = await self._get_events_from_cache(
event_ids,
)
@@ -729,12 +712,46 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id: str) -> None:
- self._get_event_cache.invalidate((event_id,))
+ def invalidate_get_event_cache_after_txn(
+ self, txn: LoggingTransaction, event_id: str
+ ) -> None:
+ """
+ Prepares a database transaction to invalidate the get event cache for a given
+ event ID when executed successfully. This is achieved by attaching two callbacks
+ to the transaction, one to invalidate the async cache and one for the in memory
+ sync cache (importantly called in that order).
+
+ Arguments:
+ txn: the database transaction to attach the callbacks to
+ event_id: the event ID to be invalidated from caches
+ """
+
+ txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
+ txn.call_after(self._invalidate_local_get_event_cache, event_id)
+
+ async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
+ """
+ Invalidates an event in the asynchronous get event cache, which may be remote.
+
+ Arguments:
+ event_id: the event ID to invalidate
+ """
+
+ await self._get_event_cache.invalidate((event_id,))
+
+ def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+ """
+ Invalidates an event in local in-memory get event caches.
+
+ Arguments:
+ event_id: the event ID to invalidate
+ """
+
+ self._get_event_cache.invalidate_local((event_id,))
self._event_ref.pop(event_id, None)
self._current_event_fetches.pop(event_id, None)
- def _get_events_from_cache(
+ async def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
@@ -749,7 +766,7 @@ class EventsWorkerStore(SQLBaseStore):
for event_id in events:
# First check if it's in the event cache
- ret = self._get_event_cache.get(
+ ret = await self._get_event_cache.get(
(event_id,), None, update_metrics=update_metrics
)
if ret:
@@ -771,7 +788,7 @@ class EventsWorkerStore(SQLBaseStore):
# We add the entry back into the cache as we want to keep
# recently queried events in the cache.
- self._get_event_cache.set((event_id,), cache_entry)
+ await self._get_event_cache.set((event_id,), cache_entry)
return event_map
@@ -965,7 +982,13 @@ class EventsWorkerStore(SQLBaseStore):
}
row_dict = self.db_pool.new_transaction(
- conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
+ conn,
+ "do_fetch",
+ [],
+ [],
+ [],
+ self._fetch_event_rows,
+ events_to_fetch,
)
# We only want to resolve deferreds from the main thread
@@ -1148,7 +1171,7 @@ class EventsWorkerStore(SQLBaseStore):
event=original_ev, redacted_event=redacted_event
)
- self._get_event_cache.set((event_id,), cache_entry)
+ await self._get_event_cache.set((event_id,), cache_entry)
result_map[event_id] = cache_entry
if not redacted_event:
@@ -1382,7 +1405,9 @@ class EventsWorkerStore(SQLBaseStore):
# if the event cache contains the event, obviously we've seen it.
cache_results = {
- (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+ (rid, eid)
+ for (rid, eid) in keys
+ if await self._get_event_cache.contains((eid,))
}
results = dict.fromkeys(cache_results, True)
remaining = [k for k in keys if k not in cache_results]
@@ -1465,7 +1490,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1481,10 +1506,11 @@ class EventsWorkerStore(SQLBaseStore):
def get_all_new_forward_event_rows(
txn: LoggingTransaction,
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+ " e.outlier"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
@@ -1498,7 +1524,8 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
return cast(
- List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+ txn.fetchall(),
)
return await self.db_pool.runInteraction(
@@ -1507,7 +1534,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1522,11 +1549,14 @@ class EventsWorkerStore(SQLBaseStore):
def get_ex_outlier_stream_rows_txn(
txn: LoggingTransaction,
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+ " e.outlier"
" FROM events AS e"
+ # NB: the next line (inner join) is what makes this query different from
+ # get_all_new_forward_event_rows.
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
@@ -1541,7 +1571,8 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, instance_name))
return cast(
- List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+ txn.fetchall(),
)
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 9a63f953fb..efd136a864 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
"initialise_mau_threepids",
[],
[],
+ [],
self._initialise_reserved_users,
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 87b0d09039..f6822707e4 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -19,6 +19,8 @@ from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines._base import IsolationLevel
from synapse.types import RoomStreamToken
logger = logging.getLogger(__name__)
@@ -302,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
- self._invalidate_get_event_cache(event_id)
+ self.invalidate_get_event_cache_after_txn(txn, event_id)
logger.info("[purge] done")
@@ -317,11 +319,38 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
Returns:
The list of state groups to delete.
"""
- return await self.db_pool.runInteraction(
- "purge_room", self._purge_room_txn, room_id
+
+ # This first runs the purge transaction with READ_COMMITTED isolation level,
+ # meaning any new rows in the tables will not trigger a serialization error.
+ # We then run the same purge a second time without this isolation level to
+ # purge any of those rows which were added during the first.
+
+ state_groups_to_delete = await self.db_pool.runInteraction(
+ "purge_room",
+ self._purge_room_txn,
+ room_id=room_id,
+ isolation_level=IsolationLevel.READ_COMMITTED,
+ )
+
+ state_groups_to_delete.extend(
+ await self.db_pool.runInteraction(
+ "purge_room",
+ self._purge_room_txn,
+ room_id=room_id,
+ ),
)
+ return state_groups_to_delete
+
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
+ # This collides with event persistence so we cannot write new events and metadata into
+ # a room while deleting it or this transaction will fail.
+ if isinstance(self.database_engine, PostgresEngine):
+ txn.execute(
+ "SELECT room_version FROM rooms WHERE room_id = ? FOR UPDATE",
+ (room_id,),
+ )
+
# First, fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 86649c1e6c..768f95d16c 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -228,6 +228,7 @@ class PushRulesWorkerStore(
iterable=user_ids,
retcols=("*",),
desc="bulk_get_push_rules",
+ batch_size=1000,
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 13d6a1d5c0..d6d485507b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -175,7 +175,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
rooms.creator, state.encryption, state.is_federatable AS federatable,
rooms.is_public AS public, state.join_rules, state.guest_access,
state.history_visibility, curr.current_state_events AS state_events,
- state.avatar, state.topic
+ state.avatar, state.topic, state.room_type
FROM rooms
LEFT JOIN room_stats_state state USING (room_id)
LEFT JOIN room_stats_current curr USING (room_id)
@@ -596,7 +596,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room, rooms.room_version, rooms.creator,
state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
- state.guest_access, state.history_visibility, curr.current_state_events
+ state.guest_access, state.history_visibility, curr.current_state_events,
+ state.room_type
FROM room_stats_state state
INNER JOIN room_stats_current curr USING (room_id)
INNER JOIN rooms USING (room_id)
@@ -646,6 +647,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"guest_access": room[11],
"history_visibility": room[12],
"state_events": room[13],
+ "room_type": room[14],
}
)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 0b5e4e4254..df6b82660e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -31,7 +31,6 @@ import attr
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
@@ -244,7 +243,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn: LoggingTransaction,
) -> Dict[str, ProfileInfo]:
clause, ids = make_in_list_sql_clause(
- self.database_engine, "m.user_id", user_ids
+ self.database_engine, "c.state_key", user_ids
)
sql = """
@@ -780,26 +779,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return shared_room_ids or frozenset()
- async def get_joined_users_from_context(
- self, event: EventBase, context: EventContext
- ) -> Dict[str, ProfileInfo]:
- state_group: Union[object, int] = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- current_state_ids = await context.get_current_state_ids()
- assert current_state_ids is not None
- assert state_group is not None
- return await self._get_joined_users_from_context(
- event.room_id, state_group, current_state_ids, event=event, context=context
- )
-
async def get_joined_users_from_state(
- self, room_id: str, state_entry: "_StateCacheEntry"
+ self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
@@ -812,18 +793,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry
+ room_id, state_group, state, context=state_entry
)
- @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+ @cached(num_args=2, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
self,
room_id: str,
state_group: Union[object, int],
current_state_ids: StateMap[str],
- cache_context: _CacheContext,
event: Optional[EventBase] = None,
- context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
+ context: Optional["_StateCacheEntry"] = None,
) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
@@ -863,7 +843,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't update the event cache hit ratio as it completely throws off
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
- event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
+ event_map = await self._get_events_from_cache(
+ member_event_ids, update_metrics=False
+ )
missing_member_event_ids = []
for event_id in member_event_ids:
@@ -922,7 +904,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
iterable=event_ids,
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
- batch_size=500,
+ batch_size=1000,
desc="_get_joined_profiles_from_event_ids",
)
@@ -1017,7 +999,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
async def get_joined_hosts(
- self, room_id: str, state_entry: "_StateCacheEntry"
+ self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
@@ -1030,7 +1012,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
- room_id, state_group, state_entry=state_entry
+ room_id, state_group, state, state_entry=state_entry
)
@cached(num_args=2, max_entries=10000, iterable=True)
@@ -1038,6 +1020,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self,
room_id: str,
state_group: Union[object, int],
+ state: StateMap[str],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
@@ -1093,7 +1076,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
joined_users = await self.get_joined_users_from_state(
- room_id, state_entry
+ room_id, state, state_entry
)
cache.hosts_to_joined_users = {}
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 3a1df7776c..2590b52f73 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1022,8 +1022,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
async def get_all_new_events_stream(
- self, from_id: int, current_id: int, limit: int
- ) -> Tuple[int, List[EventBase]]:
+ self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
+ ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
@@ -1032,19 +1032,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
+ get_prev_content: whether to fetch previous event content
Returns:
- A tuple of (next_id, events), where `next_id` is the next value to
- pass as `from_id` (it will either be the stream_ordering of the
- last returned event, or, if fewer than `limit` events were found,
- the `current_id`).
+ A tuple of (next_id, events, event_to_received_ts), where `next_id`
+ is the next value to pass as `from_id` (it will either be the
+ stream_ordering of the last returned event, or, if fewer than `limit`
+ events were found, the `current_id`). The `event_to_received_ts` is
+ a dictionary mapping event ID to the event `received_ts`.
"""
def get_all_new_events_stream_txn(
txn: LoggingTransaction,
- ) -> Tuple[int, List[str]]:
+ ) -> Tuple[int, Dict[str, Optional[int]]]:
sql = (
- "SELECT e.stream_ordering, e.event_id"
+ "SELECT e.stream_ordering, e.event_id, e.received_ts"
" FROM events AS e"
" WHERE"
" ? < e.stream_ordering AND e.stream_ordering <= ?"
@@ -1059,15 +1061,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if len(rows) == limit:
upper_bound = rows[-1][0]
- return upper_bound, [row[1] for row in rows]
+ event_to_received_ts: Dict[str, Optional[int]] = {
+ row[1]: row[2] for row in rows
+ }
+ return upper_bound, event_to_received_ts
- upper_bound, event_ids = await self.db_pool.runInteraction(
+ upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = await self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(
+ event_to_received_ts.keys(),
+ get_prev_content=get_prev_content,
+ )
- return upper_bound, events
+ return upper_bound, events, event_to_received_ts
async def get_federation_out_pos(self, typ: str) -> int:
if self._need_to_reset_federation_stream_positions:
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index fa9eadaca7..a7fcc564a9 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -24,6 +24,7 @@ from synapse.storage.database import (
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap
+from synapse.util.caches import intern_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -136,7 +137,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql % (where_clause,), args)
for row in txn:
typ, state_key, event_id = row
- key = (typ, state_key)
+ key = (intern_string(typ), intern_string(state_key))
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 609a2b88bf..afbc85ad0c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -400,14 +400,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
- current_state_ids: StateMap[str],
+ current_state_ids: Optional[StateMap[str]],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
+ At least one of `current_state_ids` and `prev_group` must be provided. Whenever
+ `prev_group` is not None, `delta_ids` must also not be None.
+
Args:
event_id: The event ID for which the state was calculated
room_id
- prev_group: A previous state group for the room, optional.
+ prev_group: A previous state group for the room.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
@@ -418,10 +421,41 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group ID
"""
- def _store_state_group_txn(txn: LoggingTransaction) -> int:
- if current_state_ids is None:
- # AFAIK, this can never happen
- raise Exception("current_state_ids cannot be None")
+ if prev_group is None and current_state_ids is None:
+ raise Exception("current_state_ids and prev_group can't both be None")
+
+ if prev_group is not None and delta_ids is None:
+ raise Exception("delta_ids is None when prev_group is not None")
+
+ def insert_delta_group_txn(
+ txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str]
+ ) -> Optional[int]:
+ """Try and persist the new group as a delta.
+
+ Requires that we have the state as a delta from a previous state group.
+
+ Returns:
+ The state group if successfully created, or None if the state
+ needs to be persisted as a full state.
+ """
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ # if the chain of state group deltas is going too long, we fall back to
+ # persisting a complete state group.
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if potential_hops >= MAX_STATE_DELTA_HOPS:
+ return None
state_group = self._state_group_seq_gen.get_next_id_txn(txn)
@@ -431,51 +465,45 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
values={"id": state_group, "room_id": room_id, "event_id": event_id},
)
- # We persist as a delta if we can, while also ensuring the chain
- # of deltas isn't tooo long, as otherwise read performance degrades.
- if prev_group:
- is_in_db = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": prev_group},
- retcol="id",
- allow_none=True,
- )
- if not is_in_db:
- raise Exception(
- "Trying to persist state with unpersisted prev_group: %r"
- % (prev_group,)
- )
-
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- assert delta_ids is not None
-
- self.db_pool.simple_insert_txn(
- txn,
- table="state_group_edges",
- values={"state_group": state_group, "prev_state_group": prev_group},
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={"state_group": state_group, "prev_state_group": prev_group},
+ )
- self.db_pool.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- keys=("state_group", "room_id", "type", "state_key", "event_id"),
- values=[
- (state_group, room_id, key[0], key[1], state_id)
- for key, state_id in delta_ids.items()
- ],
- )
- else:
- self.db_pool.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- keys=("state_group", "room_id", "type", "state_key", "event_id"),
- values=[
- (state_group, room_id, key[0], key[1], state_id)
- for key, state_id in current_state_ids.items()
- ],
- )
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (state_group, room_id, key[0], key[1], state_id)
+ for key, state_id in delta_ids.items()
+ ],
+ )
+
+ return state_group
+
+ def insert_full_state_txn(
+ txn: LoggingTransaction, current_state_ids: StateMap[str]
+ ) -> int:
+ """Persist the full state, returning the new state group."""
+ state_group = self._state_group_seq_gen.get_next_id_txn(txn)
+
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={"id": state_group, "room_id": room_id, "event_id": event_id},
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (state_group, room_id, key[0], key[1], state_id)
+ for key, state_id in current_state_ids.items()
+ ],
+ )
# Prefill the state group caches with this group.
# It's fine to use the sequence like this as the state group map
@@ -491,7 +519,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_members_cache.update,
self._state_group_members_cache.sequence,
key=state_group,
- value=dict(current_member_state_ids),
+ value=current_member_state_ids,
)
current_non_member_state_ids = {
@@ -503,13 +531,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
- value=dict(current_non_member_state_ids),
+ value=current_non_member_state_ids,
)
return state_group
+ if prev_group is not None:
+ state_group = await self.db_pool.runInteraction(
+ "store_state_group.insert_delta_group",
+ insert_delta_group_txn,
+ prev_group,
+ delta_ids,
+ )
+ if state_group is not None:
+ return state_group
+
+ # We're going to persist the state as a complete group rather than
+ # a delta, so first we need to ensure we have loaded the state map
+ # from the database.
+ if current_state_ids is None:
+ assert prev_group is not None
+ assert delta_ids is not None
+ groups = await self._get_state_for_groups([prev_group])
+ current_state_ids = dict(groups[prev_group])
+ current_state_ids.update(delta_ids)
+
return await self.db_pool.runInteraction(
- "store_state_group", _store_state_group_txn
+ "store_state_group.insert_full_state",
+ insert_full_state_txn,
+ current_state_ids,
)
async def purge_unreferenced_state_groups(
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index dc237e3032..a9a88c8bfd 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -74,13 +74,14 @@ Changes in SCHEMA_VERSION = 71:
Changes in SCHEMA_VERSION = 72:
- event_edges.(room_id, is_state) are no longer written to.
+ - Tables related to groups are dropped.
"""
SCHEMA_COMPAT_VERSION = (
- # We no longer maintain `event_edges.room_id`, so synapses with SCHEMA_VERSION < 71
- # will break.
- 71
+ # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72
+ # could break.
+ 72
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
diff --git a/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py
new file mode 100644
index 0000000000..55a5d092cc
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py
@@ -0,0 +1,47 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine, *args, **kwargs):
+ """Add a bg update to populate the `state_key` and `rejection_reason` columns of `events`"""
+
+ # we know that any new events will have the columns populated (and that has been
+ # the case since schema_version 68, so there is no chance of rolling back now).
+ #
+ # So, we only need to make sure that existing rows are updated. We read the
+ # current min and max stream orderings, since that is guaranteed to include all
+ # the events that were stored before the new columns were added.
+ cur.execute("SELECT MIN(stream_ordering), MAX(stream_ordering) FROM events")
+ (min_stream_ordering, max_stream_ordering) = cur.fetchone()
+
+ if min_stream_ordering is None:
+ # no rows, nothing to do.
+ return
+
+ cur.execute(
+ "INSERT into background_updates (ordering, update_name, progress_json)"
+ " VALUES (7203, 'events_populate_state_key_rejections', ?)",
+ (
+ json.dumps(
+ {
+ "min_stream_ordering_exclusive": min_stream_ordering - 1,
+ "max_stream_ordering_inclusive": max_stream_ordering,
+ }
+ ),
+ ),
+ )
diff --git a/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql
new file mode 100644
index 0000000000..0da668aa3a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql
@@ -0,0 +1,17 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- event_reference_hashes is unused, so we can drop it
+DROP TABLE event_reference_hashes;
diff --git a/synapse/storage/schema/main/delta/72/03remove_groups.sql b/synapse/storage/schema/main/delta/72/03remove_groups.sql
new file mode 100644
index 0000000000..b7c5894de8
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03remove_groups.sql
@@ -0,0 +1,31 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Remove the tables which powered the unspecced groups/communities feature.
+DROP TABLE IF EXISTS group_attestations_remote;
+DROP TABLE IF EXISTS group_attestations_renewals;
+DROP TABLE IF EXISTS group_invites;
+DROP TABLE IF EXISTS group_roles;
+DROP TABLE IF EXISTS group_room_categories;
+DROP TABLE IF EXISTS group_rooms;
+DROP TABLE IF EXISTS group_summary_roles;
+DROP TABLE IF EXISTS group_summary_room_categories;
+DROP TABLE IF EXISTS group_summary_rooms;
+DROP TABLE IF EXISTS group_summary_users;
+DROP TABLE IF EXISTS group_users;
+DROP TABLE IF EXISTS groups;
+DROP TABLE IF EXISTS local_group_membership;
+DROP TABLE IF EXISTS local_group_updates;
+DROP TABLE IF EXISTS remote_profile_cache;
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 211437cfaa..466e5137f2 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -166,6 +166,7 @@ class PartialCurrentStateTracker:
logger.info(
"Awaiting un-partial-stating of room %s",
room_id,
+ stack_info=True,
)
await make_deferred_yieldable(d)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 8ed5325c5d..31f41fec82 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -730,3 +730,41 @@ class LruCache(Generic[KT, VT]):
# This happens e.g. in the sync code where we have an expiring cache of
# lru caches.
self.clear()
+
+
+class AsyncLruCache(Generic[KT, VT]):
+ """
+ An asynchronous wrapper around a subset of the LruCache API.
+
+ On its own this doesn't change the behaviour but allows subclasses that
+ utilize external cache systems that require await behaviour to be created.
+ """
+
+ def __init__(self, *args, **kwargs): # type: ignore
+ self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs)
+
+ async def get(
+ self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+ ) -> Optional[VT]:
+ return self._lru_cache.get(key, update_metrics=update_metrics)
+
+ async def set(self, key: KT, value: VT) -> None:
+ self._lru_cache.set(key, value)
+
+ async def invalidate(self, key: KT) -> None:
+ # This method should invalidate any external cache and then invalidate the LruCache.
+ return self._lru_cache.invalidate(key)
+
+ def invalidate_local(self, key: KT) -> None:
+ """Remove an entry from the local cache
+
+ This variant of `invalidate` is useful if we know that the external
+ cache has already been invalidated.
+ """
+ return self._lru_cache.invalidate(key)
+
+ async def contains(self, key: KT) -> bool:
+ return self._lru_cache.contains(key)
+
+ async def clear(self) -> None:
+ self._lru_cache.clear()
|