diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 19ca399d44..9293808640 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -50,7 +50,7 @@ def request_registration(
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
# Get the nonce
- r = requests.get(url, verify=False)
+ r = requests.get(url)
if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
@@ -88,7 +88,7 @@ def request_registration(
}
_print("Sending registration request...")
- r = requests.post(url, json=data, verify=False)
+ r = requests.post(url, json=data)
if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index ab2b29cf1b..ef8590db65 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -191,7 +191,7 @@ IGNORED_TABLES = {
"user_directory_search_stat",
"user_directory_search_pos",
"users_who_share_private_rooms",
- "users_in_public_room",
+ "users_in_public_rooms",
# UI auth sessions have foreign keys so additional care needs to be taken,
# the sessions are transient anyway, so ignore them.
"ui_auth_sessions",
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 72d30da300..f9e18d2053 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -368,9 +368,14 @@ class ServerConfig(Config):
# Whether to enable user presence.
presence_config = config.get("presence") or {}
- self.use_presence = presence_config.get("enabled")
- if self.use_presence is None:
- self.use_presence = config.get("use_presence", True)
+ presence_enabled = presence_config.get("enabled")
+ if presence_enabled is None:
+ presence_enabled = config.get("use_presence", True)
+
+ # Whether presence is enabled *at all*.
+ self.presence_enabled = bool(presence_enabled)
+ # Whether to internally track presence, requires that presence is enabled,
+ self.track_presence = self.presence_enabled and presence_enabled != "untracked"
# Custom presence router module
# This is the legacy way of configuring it (the config should now be put in the modules section)
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index f1766088fc..6d67a8cd5c 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -358,9 +358,9 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `account_data` messages."
)
- if len(self.writers.receipts) != 1:
+ if len(self.writers.receipts) == 0:
raise ConfigError(
- "Must only specify one instance to handle `receipts` messages."
+ "Must specify at least one instance to handle `receipts` messages."
)
if len(self.writers.events) == 0:
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 53af423a5a..ac2cf83d9f 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -17,6 +17,7 @@ import re
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Dict,
Iterable,
@@ -45,6 +46,7 @@ from . import EventBase
if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
+ from synapse.server import HomeServer
# Split strings on "." but not "\." (or "\\\.").
@@ -56,6 +58,13 @@ CANONICALJSON_MAX_INT = (2**53) - 1
CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
+# Module API callback that allows adding fields to the unsigned section of
+# events that are sent to clients.
+ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK = Callable[
+ [EventBase], Awaitable[JsonDict]
+]
+
+
def prune_event(event: EventBase) -> EventBase:
"""Returns a pruned version of the given event, which removes all keys we
don't know about or think could potentially be dodgy.
@@ -509,7 +518,13 @@ class EventClientSerializer:
clients.
"""
- def serialize_event(
+ def __init__(self, hs: "HomeServer") -> None:
+ self._store = hs.get_datastores().main
+ self._add_extra_fields_to_unsigned_client_event_callbacks: List[
+ ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
+ ] = []
+
+ async def serialize_event(
self,
event: Union[JsonDict, EventBase],
time_now: int,
@@ -535,10 +550,21 @@ class EventClientSerializer:
serialized_event = serialize_event(event, time_now, config=config)
+ new_unsigned = {}
+ for callback in self._add_extra_fields_to_unsigned_client_event_callbacks:
+ u = await callback(event)
+ new_unsigned.update(u)
+
+ if new_unsigned:
+ # We do the `update` this way round so that modules can't clobber
+ # existing fields.
+ new_unsigned.update(serialized_event["unsigned"])
+ serialized_event["unsigned"] = new_unsigned
+
# Check if there are any bundled aggregations to include with the event.
if bundle_aggregations:
if event.event_id in bundle_aggregations:
- self._inject_bundled_aggregations(
+ await self._inject_bundled_aggregations(
event,
time_now,
config,
@@ -548,7 +574,7 @@ class EventClientSerializer:
return serialized_event
- def _inject_bundled_aggregations(
+ async def _inject_bundled_aggregations(
self,
event: EventBase,
time_now: int,
@@ -590,7 +616,7 @@ class EventClientSerializer:
# said that we should only include the `event_id`, `origin_server_ts` and
# `sender` of the edit; however MSC3925 proposes extending it to the whole
# of the edit, which is what we do here.
- serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event(
+ serialized_aggregations[RelationTypes.REPLACE] = await self.serialize_event(
event_aggregations.replace,
time_now,
config=config,
@@ -600,7 +626,7 @@ class EventClientSerializer:
if event_aggregations.thread:
thread = event_aggregations.thread
- serialized_latest_event = self.serialize_event(
+ serialized_latest_event = await self.serialize_event(
thread.latest_event,
time_now,
config=config,
@@ -623,7 +649,7 @@ class EventClientSerializer:
"m.relations", {}
).update(serialized_aggregations)
- def serialize_events(
+ async def serialize_events(
self,
events: Iterable[Union[JsonDict, EventBase]],
time_now: int,
@@ -645,7 +671,7 @@ class EventClientSerializer:
The list of serialized events
"""
return [
- self.serialize_event(
+ await self.serialize_event(
event,
time_now,
config=config,
@@ -654,6 +680,14 @@ class EventClientSerializer:
for event in events
]
+ def register_add_extra_fields_to_unsigned_client_event_callback(
+ self, callback: ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
+ ) -> None:
+ """Register a callback that returns additions to the unsigned section of
+ serialized events.
+ """
+ self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback)
+
_PowerLevel = Union[str, int]
PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 356ab0492b..8e3064c7e7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1401,7 +1401,7 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
- if not self.config.server.use_presence and edu_type == EduTypes.PRESENCE:
+ if not self.config.server.track_presence and edu_type == EduTypes.PRESENCE:
return
# Check if we have a handler on this instance
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 7b6b1da090..7980d1a322 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -844,7 +844,7 @@ class FederationSender(AbstractFederationSender):
destinations (list[str])
"""
- if not states or not self.hs.config.server.use_presence:
+ if not states or not self.hs.config.server.track_presence:
# No-op if presence is disabled.
return
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index c200a45f3a..873dadc3bd 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -47,6 +47,7 @@ from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
+ MultiWriterStreamToken,
RoomAlias,
RoomStreamToken,
StreamKeyType,
@@ -217,7 +218,7 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
stream_key: StreamKeyType,
- new_token: Union[int, RoomStreamToken],
+ new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
"""
@@ -259,19 +260,6 @@ class ApplicationServicesHandler:
):
return
- # Assert that new_token is an integer (and not a RoomStreamToken).
- # All of the supported streams that this function handles use an
- # integer to track progress (rather than a RoomStreamToken - a
- # vector clock implementation) as they don't support multiple
- # stream writers.
- #
- # As a result, we simply assert that new_token is an integer.
- # If we do end up needing to pass a RoomStreamToken down here
- # in the future, using RoomStreamToken.stream (the minimum stream
- # position) to convert to an ascending integer value should work.
- # Additional context: https://github.com/matrix-org/synapse/pull/11137
- assert isinstance(new_token, int)
-
# Ignore to-device messages if the feature flag is not enabled
if (
stream_key == StreamKeyType.TO_DEVICE
@@ -286,6 +274,9 @@ class ApplicationServicesHandler:
):
return
+ # We know we're not a `RoomStreamToken` at this point.
+ assert not isinstance(new_token, RoomStreamToken)
+
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
@@ -327,7 +318,7 @@ class ApplicationServicesHandler:
self,
services: List[ApplicationService],
stream_key: StreamKeyType,
- new_token: int,
+ new_token: Union[int, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
logger.debug("Checking interested services for %s", stream_key)
@@ -340,6 +331,7 @@ class ApplicationServicesHandler:
#
# Instead we simply grab the latest typing updates in _handle_typing
# and, if they apply to this application service, send it off.
+ assert isinstance(new_token, int)
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -350,15 +342,23 @@ class ApplicationServicesHandler:
(service.id, stream_key)
):
if stream_key == StreamKeyType.RECEIPT:
+ assert isinstance(new_token, MultiWriterStreamToken)
+
+ # We store appservice tokens as integers, so we ignore
+ # the `instance_map` components and instead simply
+ # follow the base stream position.
+ new_token = MultiWriterStreamToken(stream=new_token.stream)
+
events = await self._handle_receipts(service, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
- service, "read_receipt", new_token
+ service, "read_receipt", new_token.stream
)
elif stream_key == StreamKeyType.PRESENCE:
+ assert isinstance(new_token, int)
events = await self._handle_presence(service, users, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
@@ -368,6 +368,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.TO_DEVICE:
+ assert isinstance(new_token, int)
# Retrieve a list of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
to_device_messages = await self._get_to_device_messages(
@@ -383,6 +384,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.DEVICE_LIST:
+ assert isinstance(new_token, int)
device_list_summary = await self._get_device_list_summary(
service, new_token
)
@@ -432,7 +434,7 @@ class ApplicationServicesHandler:
return typing
async def _handle_receipts(
- self, service: ApplicationService, new_token: int
+ self, service: ApplicationService, new_token: MultiWriterStreamToken
) -> List[JsonMapping]:
"""
Return the latest read receipts that the given application service should receive.
@@ -455,15 +457,17 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
- if new_token is not None and new_token <= from_key:
+ if new_token is not None and new_token.stream <= from_key:
logger.debug(
"Rejecting token lower than or equal to stored: %s" % (new_token,)
)
return []
+ from_token = MultiWriterStreamToken(stream=from_key)
+
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
- service=service, from_key=from_key, to_key=new_token
+ service=service, from_key=from_token, to_key=new_token
)
return receipts
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 6a8f8f2fd1..370f4041fb 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -103,10 +103,10 @@ class DeactivateAccountHandler:
# Attempt to unbind any known bound threepids to this account from identity
# server(s).
bound_threepids = await self.store.user_get_bound_threepids(user_id)
- for threepid in bound_threepids:
+ for medium, address in bound_threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
- user_id, threepid["medium"], threepid["address"], id_server
+ user_id, medium, address, id_server
)
except Exception:
# Do we want this to be a fatal error or should we carry on?
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b0f6011629..93472d0117 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -595,6 +595,8 @@ class DeviceHandler(DeviceWorkerHandler):
)
# Delete device messages asynchronously and in batches using the task scheduler
+ # We specify an upper stream id to avoid deleting non delivered messages
+ # if an user re-uses a device ID.
await self._task_scheduler.schedule_task(
DELETE_DEVICE_MSGS_TASK_NAME,
resource_id=device_id,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5a0c1f47be..d06524495f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -665,6 +665,20 @@ class E2eKeysHandler:
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
+ """
+ Args:
+ query: A chain of maps from (user_id, device_id, algorithm) to the requested
+ number of keys to claim.
+ user: The user who is claiming these keys.
+ timeout: How long to wait for any federation key claim requests before
+ giving up.
+ always_include_fallback_keys: always include a fallback key for local users'
+ devices, even if we managed to claim a one-time-key.
+
+ Returns: a heterogeneous dict with two keys:
+ one_time_keys: chain of maps user ID -> device ID -> key ID -> key.
+ failures: map from remote destination to a JsonDict describing the error.
+ """
local_query: List[Tuple[str, str, str, int]] = []
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
@@ -745,6 +759,16 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
+ """
+ Args:
+ user_id: user whose keys are being uploaded.
+ device_id: device whose keys are being uploaded.
+ keys: the body of a /keys/upload request.
+
+ Returns a dictionary with one field:
+ "one_time_keys": A mapping from algorithm to number of keys for that
+ algorithm, including those previously persisted.
+ """
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d12803bf0f..756825061c 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -120,7 +120,7 @@ class EventStreamHandler:
events.extend(to_add)
- chunks = self._event_serializer.serialize_events(
+ chunks = await self._event_serializer.serialize_events(
events,
time_now,
config=SerializeEventConfig(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 472879c964..c041b67993 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -19,6 +19,8 @@ import logging
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
+import attr
+
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -357,9 +359,9 @@ class IdentityHandler:
# Check to see if a session already exists and that it is not yet
# marked as validated
- if session and session.get("validated_at") is None:
- session_id = session["session_id"]
- last_send_attempt = session["last_send_attempt"]
+ if session and session.validated_at is None:
+ session_id = session.session_id
+ last_send_attempt = session.last_send_attempt
# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
@@ -480,7 +482,6 @@ class IdentityHandler:
# We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn
- validation_session = None
# Try to validate as email
if self.hs.config.email.can_verify_email:
@@ -488,19 +489,18 @@ class IdentityHandler:
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)
-
- if validation_session:
- return validation_session
+ if validation_session:
+ return attr.asdict(validation_session)
# Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
- validation_session = await self.threepid_from_creds(
+ return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds,
)
- return validation_session
+ return None
async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index c34bd7db95..c4bec955fe 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -145,7 +145,7 @@ class InitialSyncHandler:
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
receipt = await self.store.get_linearized_receipts_for_rooms(
joined_rooms,
- to_key=int(now_token.receipt_key),
+ to_key=now_token.receipt_key,
)
receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
@@ -173,7 +173,7 @@ class InitialSyncHandler:
d["inviter"] = event.sender
invite_event = await self.store.get_event(event.event_id)
- d["invite"] = self._event_serializer.serialize_event(
+ d["invite"] = await self._event_serializer.serialize_event(
invite_event,
time_now,
config=serializer_options,
@@ -225,7 +225,7 @@ class InitialSyncHandler:
d["messages"] = {
"chunk": (
- self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
messages,
time_now=time_now,
config=serializer_options,
@@ -235,7 +235,7 @@ class InitialSyncHandler:
"end": await end_token.to_string(self.store),
}
- d["state"] = self._event_serializer.serialize_events(
+ d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
config=serializer_options,
@@ -387,7 +387,7 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
- self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
)
),
@@ -396,7 +396,7 @@ class InitialSyncHandler:
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
- self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
room_state.values(), time_now, config=serialize_options
)
),
@@ -420,7 +420,7 @@ class InitialSyncHandler:
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
# Don't bundle aggregations as this is a deprecated API.
- state = self._event_serializer.serialize_events(
+ state = await self._event_serializer.serialize_events(
current_state.values(),
time_now,
config=serialize_options,
@@ -439,7 +439,7 @@ class InitialSyncHandler:
async def get_presence() -> List[JsonDict]:
# If presence is disabled, return an empty list
- if not self.hs.config.server.use_presence:
+ if not self.hs.config.server.presence_enabled:
return []
states = await presence_handler.get_states(
@@ -497,7 +497,7 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
- self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
)
),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 41a35ce510..811a41f161 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -244,7 +244,7 @@ class MessageHandler:
)
room_state = room_state_events[membership_event_id]
- events = self._event_serializer.serialize_events(
+ events = await self._event_serializer.serialize_events(
room_state.values(),
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),
@@ -999,7 +999,26 @@ class EventCreationHandler:
raise ShadowBanError()
if ratelimit:
- await self.request_ratelimiter.ratelimit(requester, update=False)
+ room_id = event_dict["room_id"]
+ try:
+ room_version = await self.store.get_room_version(room_id)
+ except NotFoundError:
+ # The room doesn't exist.
+ raise AuthError(403, f"User {requester.user} not in room {room_id}")
+
+ if room_version.updated_redaction_rules:
+ redacts = event_dict["content"].get("redacts")
+ else:
+ redacts = event_dict.get("redacts")
+
+ is_admin_redaction = await self.is_admin_redaction(
+ event_type=event_dict["type"],
+ sender=event_dict["sender"],
+ redacts=redacts,
+ )
+ await self.request_ratelimiter.ratelimit(
+ requester, is_admin_redaction=is_admin_redaction, update=False
+ )
# We limit the number of concurrent event sends in a room so that we
# don't fork the DAG too much. If we don't limit then we can end up in
@@ -1508,6 +1527,18 @@ class EventCreationHandler:
first_event.room_id
)
if writer_instance != self._instance_name:
+ # Ratelimit before sending to the other event persister, to
+ # ensure that we correctly have ratelimits on both the event
+ # creators and event persisters.
+ if ratelimit:
+ for event, _ in events_and_context:
+ is_admin_redaction = await self.is_admin_redaction(
+ event.type, event.sender, event.redacts
+ )
+ await self.request_ratelimiter.ratelimit(
+ requester, is_admin_redaction=is_admin_redaction
+ )
+
try:
result = await self.send_events(
instance_name=writer_instance,
@@ -1538,6 +1569,7 @@ class EventCreationHandler:
# stream_ordering entry manually (as it was persisted on
# another worker).
event.internal_metadata.stream_ordering = stream_id
+
return event
event = await self.persist_and_notify_client_events(
@@ -1696,21 +1728,9 @@ class EventCreationHandler:
# can apply different ratelimiting. We do this by simply checking
# it's not a self-redaction (to avoid having to look up whether the
# user is actually admin or not).
- is_admin_redaction = False
- if event.type == EventTypes.Redaction:
- assert event.redacts is not None
-
- original_event = await self.store.get_event(
- event.redacts,
- redact_behaviour=EventRedactBehaviour.as_is,
- get_prev_content=False,
- allow_rejected=False,
- allow_none=True,
- )
-
- is_admin_redaction = bool(
- original_event and event.sender != original_event.sender
- )
+ is_admin_redaction = await self.is_admin_redaction(
+ event.type, event.sender, event.redacts
+ )
await self.request_ratelimiter.ratelimit(
requester, is_admin_redaction=is_admin_redaction
@@ -1930,6 +1950,27 @@ class EventCreationHandler:
return persisted_events[-1]
+ async def is_admin_redaction(
+ self, event_type: str, sender: str, redacts: Optional[str]
+ ) -> bool:
+ """Return whether the event is a redaction made by an admin, and thus
+ should use a different ratelimiter.
+ """
+ if event_type != EventTypes.Redaction:
+ return False
+
+ assert redacts is not None
+
+ original_event = await self.store.get_event(
+ redacts,
+ redact_behaviour=EventRedactBehaviour.as_is,
+ get_prev_content=False,
+ allow_rejected=False,
+ allow_none=True,
+ )
+
+ return bool(original_event and sender != original_event.sender)
+
async def _maybe_kick_guest_users(
self, event: EventBase, context: EventContext
) -> None:
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 878f267a4e..87e51bca48 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -657,7 +657,7 @@ class PaginationHandler:
chunk = {
"chunk": (
- self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
events,
time_now,
config=serialize_options,
@@ -669,7 +669,7 @@ class PaginationHandler:
}
if state:
- chunk["state"] = self._event_serializer.serialize_events(
+ chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, config=serialize_options
)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index dfc0b9db07..202beee738 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -192,7 +192,8 @@ class BasePresenceHandler(abc.ABC):
self.state = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
- self._presence_enabled = hs.config.server.use_presence
+ self._presence_enabled = hs.config.server.presence_enabled
+ self._track_presence = hs.config.server.track_presence
self._federation = None
if hs.should_send_federation():
@@ -512,7 +513,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
)
async def _on_shutdown(self) -> None:
- if self._presence_enabled:
+ if self._track_presence:
self.hs.get_replication_command_handler().send_command(
ClearUserSyncsCommand(self.instance_id)
)
@@ -524,7 +525,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
is_syncing: bool,
last_sync_ms: int,
) -> None:
- if self._presence_enabled:
+ if self._track_presence:
self.hs.get_replication_command_handler().send_user_sync(
self.instance_id, user_id, device_id, is_syncing, last_sync_ms
)
@@ -571,7 +572,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
Called by the sync and events servlets to record that a user has connected to
this worker and is waiting for some events.
"""
- if not affect_presence or not self._presence_enabled:
+ if not affect_presence or not self._track_presence:
return _NullContextManager()
# Note that this causes last_active_ts to be incremented which is not
@@ -702,8 +703,8 @@ class WorkerPresenceHandler(BasePresenceHandler):
user_id = target_user.to_string()
- # If presence is disabled, no-op
- if not self._presence_enabled:
+ # If tracking of presence is disabled, no-op
+ if not self._track_presence:
return
# Proxy request to instance that writes presence
@@ -723,7 +724,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self._presence_enabled:
+ if not self._track_presence:
return
# Proxy request to instance that writes presence
@@ -760,7 +761,7 @@ class PresenceHandler(BasePresenceHandler):
] = {}
now = self.clock.time_msec()
- if self._presence_enabled:
+ if self._track_presence:
for state in self.user_to_current_state.values():
# Create a psuedo-device to properly handle time outs. This will
# be overridden by any "real" devices within SYNC_ONLINE_TIMEOUT.
@@ -831,7 +832,7 @@ class PresenceHandler(BasePresenceHandler):
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
- if self._presence_enabled:
+ if self._track_presence:
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
@@ -839,6 +840,9 @@ class PresenceHandler(BasePresenceHandler):
30, self.clock.looping_call, self._handle_timeouts, 5000
)
+ # Presence information is persisted, whether or not it is being tracked
+ # internally.
+ if self._presence_enabled:
self.clock.call_later(
60,
self.clock.looping_call,
@@ -854,7 +858,7 @@ class PresenceHandler(BasePresenceHandler):
)
# Used to handle sending of presence to newly joined users/servers
- if self._presence_enabled:
+ if self._track_presence:
self.notifier.add_replication_callback(self.notify_new_event)
# Presence is best effort and quickly heals itself, so lets just always
@@ -905,7 +909,9 @@ class PresenceHandler(BasePresenceHandler):
)
async def _update_states(
- self, new_states: Iterable[UserPresenceState], force_notify: bool = False
+ self,
+ new_states: Iterable[UserPresenceState],
+ force_notify: bool = False,
) -> None:
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
@@ -943,7 +949,7 @@ class PresenceHandler(BasePresenceHandler):
for new_state in new_states:
user_id = new_state.user_id
- # Its fine to not hit the database here, as the only thing not in
+ # It's fine to not hit the database here, as the only thing not in
# the current state cache are OFFLINE states, where the only field
# of interest is last_active which is safe enough to assume is 0
# here.
@@ -957,6 +963,9 @@ class PresenceHandler(BasePresenceHandler):
is_mine=self.is_mine_id(user_id),
wheel_timer=self.wheel_timer,
now=now,
+ # When overriding disabled presence, don't kick off all the
+ # wheel timers.
+ persist=not self._track_presence,
)
if force_notify:
@@ -1072,7 +1081,7 @@ class PresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self._presence_enabled:
+ if not self._track_presence:
return
user_id = user.to_string()
@@ -1124,7 +1133,7 @@ class PresenceHandler(BasePresenceHandler):
client that is being used by a user.
presence_state: The presence state indicated in the sync request
"""
- if not affect_presence or not self._presence_enabled:
+ if not affect_presence or not self._track_presence:
return _NullContextManager()
curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0)
@@ -1284,7 +1293,7 @@ class PresenceHandler(BasePresenceHandler):
async def incoming_presence(self, origin: str, content: JsonDict) -> None:
"""Called when we receive a `m.presence` EDU from a remote server."""
- if not self._presence_enabled:
+ if not self._track_presence:
return
now = self.clock.time_msec()
@@ -1359,7 +1368,7 @@ class PresenceHandler(BasePresenceHandler):
raise SynapseError(400, "Invalid presence state")
# If presence is disabled, no-op
- if not self._presence_enabled:
+ if not self._track_presence:
return
user_id = target_user.to_string()
@@ -2118,6 +2127,7 @@ def handle_update(
is_mine: bool,
wheel_timer: WheelTimer,
now: int,
+ persist: bool,
) -> Tuple[UserPresenceState, bool, bool]:
"""Given a presence update:
1. Add any appropriate timers.
@@ -2129,6 +2139,8 @@ def handle_update(
is_mine: Whether the user is ours
wheel_timer
now: Time now in ms
+ persist: True if this state should persist until another update occurs.
+ Skips insertion into wheel timers.
Returns:
3-tuple: `(new_state, persist_and_notify, federation_ping)` where:
@@ -2146,14 +2158,15 @@ def handle_update(
if is_mine:
if new_state.state == PresenceState.ONLINE:
# Idle timer
- wheel_timer.insert(
- now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER
- )
+ if not persist:
+ wheel_timer.insert(
+ now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER
+ )
active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY
new_state = new_state.copy_and_replace(currently_active=active)
- if active:
+ if active and not persist:
wheel_timer.insert(
now=now,
obj=user_id,
@@ -2162,11 +2175,12 @@ def handle_update(
if new_state.state != PresenceState.OFFLINE:
# User has stopped syncing
- wheel_timer.insert(
- now=now,
- obj=user_id,
- then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
- )
+ if not persist:
+ wheel_timer.insert(
+ now=now,
+ obj=user_id,
+ then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
+ )
last_federate = new_state.last_federation_update_ts
if now - last_federate > FEDERATION_PING_INTERVAL:
@@ -2174,7 +2188,7 @@ def handle_update(
new_state = new_state.copy_and_replace(last_federation_update_ts=now)
federation_ping = True
- if new_state.state == PresenceState.BUSY:
+ if new_state.state == PresenceState.BUSY and not persist:
wheel_timer.insert(
now=now,
obj=user_id,
@@ -2182,11 +2196,13 @@ def handle_update(
)
else:
- wheel_timer.insert(
- now=now,
- obj=user_id,
- then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
- )
+ # An update for a remote user was received.
+ if not persist:
+ wheel_timer.insert(
+ now=now,
+ obj=user_id,
+ then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
+ )
# Check whether the change was something worth notifying about
if should_notify(prev_state, new_state, is_mine):
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 69ac468f75..b5f7a8b47e 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -20,6 +20,7 @@ from synapse.streams import EventSource
from synapse.types import (
JsonDict,
JsonMapping,
+ MultiWriterStreamToken,
ReadReceipt,
StreamKeyType,
UserID,
@@ -200,7 +201,7 @@ class ReceiptsHandler:
await self.federation_sender.send_read_receipt(receipt)
-class ReceiptEventSource(EventSource[int, JsonMapping]):
+class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.config = hs.config
@@ -273,13 +274,12 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
async def get_new_events(
self,
user: UserID,
- from_key: int,
+ from_key: MultiWriterStreamToken,
limit: int,
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
- ) -> Tuple[List[JsonMapping], int]:
- from_key = int(from_key)
+ ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
to_key = self.get_current_key()
if from_key == to_key:
@@ -296,8 +296,11 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
return events, to_key
async def get_new_events_as(
- self, from_key: int, to_key: int, service: ApplicationService
- ) -> Tuple[List[JsonMapping], int]:
+ self,
+ from_key: MultiWriterStreamToken,
+ to_key: MultiWriterStreamToken,
+ service: ApplicationService,
+ ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
"""Returns a set of new read receipt events that an appservice
may be interested in.
@@ -312,8 +315,6 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
appservice may be interested in.
* The current read receipt stream token.
"""
- from_key = int(from_key)
-
if from_key == to_key:
return [], to_key
@@ -333,5 +334,5 @@ class ReceiptEventSource(EventSource[int, JsonMapping]):
return events, to_key
- def get_current_key(self) -> int:
+ def get_current_key(self) -> MultiWriterStreamToken:
return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 9b13448cdd..a15983afae 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -167,7 +167,7 @@ class RelationsHandler:
now = self._clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
return_value: JsonDict = {
- "chunk": self._event_serializer.serialize_events(
+ "chunk": await self._event_serializer.serialize_events(
events,
now,
bundle_aggregations=aggregations,
@@ -177,7 +177,9 @@ class RelationsHandler:
if include_original_event:
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
- return_value["original_event"] = self._event_serializer.serialize_event(
+ return_value[
+ "original_event"
+ ] = await self._event_serializer.serialize_event(
event,
now,
bundle_aggregations=None,
@@ -602,7 +604,7 @@ class RelationsHandler:
)
now = self._clock.time_msec()
- serialized_events = self._event_serializer.serialize_events(
+ serialized_events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 97c9f01245..6d680b0795 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1939,9 +1939,10 @@ class RoomShutdownHandler:
else:
logger.info("Shutting down room %r", room_id)
- users = await self.store.get_users_in_room(room_id)
- for user_id in users:
- if not self.hs.is_mine_id(user_id):
+ users = await self.store.get_local_users_related_to_room(room_id)
+ for user_id, membership in users:
+ # If the user is not in the room (or is banned), nothing to do.
+ if membership not in (Membership.JOIN, Membership.INVITE, Membership.KNOCK):
continue
logger.info("Kicking %r from %r...", user_id, room_id)
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index aad4706f14..f51ed9d5bb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -374,13 +374,13 @@ class SearchHandler:
serialize_options = SerializeEventConfig(requester=requester)
for context in contexts.values():
- context["events_before"] = self._event_serializer.serialize_events(
+ context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"],
time_now,
bundle_aggregations=aggregations,
config=serialize_options,
)
- context["events_after"] = self._event_serializer.serialize_events(
+ context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"],
time_now,
bundle_aggregations=aggregations,
@@ -390,7 +390,7 @@ class SearchHandler:
results = [
{
"rank": search_result.rank_map[e.event_id],
- "result": self._event_serializer.serialize_event(
+ "result": await self._event_serializer.serialize_event(
e,
time_now,
bundle_aggregations=aggregations,
@@ -409,7 +409,7 @@ class SearchHandler:
if state_results:
rooms_cat_res["state"] = {
- room_id: self._event_serializer.serialize_events(
+ room_id: await self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
)
for room_id, state_events in state_results.items()
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index e9a544e754..62f2454f5d 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -1206,10 +1206,7 @@ class SsoHandler:
# We have no guarantee that all the devices of that session are for the same
# `user_id`. Hence, we have to iterate over the list of devices and log them out
# one by one.
- for device in devices:
- user_id = device["user_id"]
- device_id = device["device_id"]
-
+ for user_id, device_id in devices:
# If the user_id associated with that device/session is not the one we got
# out of the `sub` claim, skip that device and show log an error.
if expected_user_id is not None and user_id != expected_user_id:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 60b4d95cd7..2f1bc5a015 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -57,6 +57,7 @@ from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
+ MultiWriterStreamToken,
MutableStateMap,
Requester,
RoomStreamToken,
@@ -477,7 +478,11 @@ class SyncHandler:
event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
- receipt_key = since_token.receipt_key if since_token else 0
+ receipt_key = (
+ since_token.receipt_key
+ if since_token
+ else MultiWriterStreamToken(stream=0)
+ )
receipt_source = self.event_sources.sources.receipt
receipts, receipt_key = await receipt_source.get_new_events(
@@ -500,12 +505,27 @@ class SyncHandler:
async def _load_filtered_recents(
self,
room_id: str,
+ sync_result_builder: "SyncResultBuilder",
sync_config: SyncConfig,
- now_token: StreamToken,
+ upto_token: StreamToken,
since_token: Optional[StreamToken] = None,
potential_recents: Optional[List[EventBase]] = None,
newly_joined_room: bool = False,
) -> TimelineBatch:
+ """Create a timeline batch for the room
+
+ Args:
+ room_id
+ sync_result_builder
+ sync_config
+ upto_token: The token up to which we should fetch (more) events.
+ If `potential_results` is non-empty then this is *start* of
+ the the list.
+ since_token
+ potential_recents: If non-empty, the events between the since token
+ and current token to send down to clients.
+ newly_joined_room
+ """
with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit()
block_all_timeline = (
@@ -521,6 +541,20 @@ class SyncHandler:
else:
limited = False
+ # Check if there is a gap, if so we need to mark this as limited and
+ # recalculate which events to send down.
+ gap_token = await self.store.get_timeline_gaps(
+ room_id,
+ since_token.room_key if since_token else None,
+ sync_result_builder.now_token.room_key,
+ )
+ if gap_token:
+ # There's a gap, so we need to ignore the passed in
+ # `potential_recents`, and reset `upto_token` to match.
+ potential_recents = None
+ upto_token = sync_result_builder.now_token
+ limited = True
+
log_kv({"limited": limited})
if potential_recents:
@@ -559,10 +593,10 @@ class SyncHandler:
recents = []
if not limited or block_all_timeline:
- prev_batch_token = now_token
+ prev_batch_token = upto_token
if recents:
room_key = recents[0].internal_metadata.before
- prev_batch_token = now_token.copy_and_replace(
+ prev_batch_token = upto_token.copy_and_replace(
StreamKeyType.ROOM, room_key
)
@@ -573,11 +607,15 @@ class SyncHandler:
filtering_factor = 2
load_limit = max(timeline_limit * filtering_factor, 10)
max_repeat = 5 # Only try a few times per room, otherwise
- room_key = now_token.room_key
+ room_key = upto_token.room_key
end_key = room_key
since_key = None
- if since_token and not newly_joined_room:
+ if since_token and gap_token:
+ # If there is a gap then we need to only include events after
+ # it.
+ since_key = gap_token
+ elif since_token and not newly_joined_room:
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
@@ -647,7 +685,7 @@ class SyncHandler:
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
- prev_batch_token = now_token.copy_and_replace(StreamKeyType.ROOM, room_key)
+ prev_batch_token = upto_token.copy_and_replace(StreamKeyType.ROOM, room_key)
# Don't bother to bundle aggregations if the timeline is unlimited,
# as clients will have all the necessary information.
@@ -662,7 +700,9 @@ class SyncHandler:
return TimelineBatch(
events=recents,
prev_batch=prev_batch_token,
- limited=limited or newly_joined_room,
+ # Also mark as limited if this is a new room or there has been a gap
+ # (to force client to paginate the gap).
+ limited=limited or newly_joined_room or gap_token is not None,
bundled_aggregations=bundled_aggregations,
)
@@ -1477,7 +1517,7 @@ class SyncHandler:
# Presence data is included if the server has it enabled and not filtered out.
include_presence_data = bool(
- self.hs_config.server.use_presence
+ self.hs_config.server.presence_enabled
and not sync_config.filter_collection.blocks_all_presence()
)
# Device list updates are sent if a since token is provided.
@@ -2397,8 +2437,9 @@ class SyncHandler:
batch = await self._load_filtered_recents(
room_id,
+ sync_result_builder,
sync_config,
- now_token=upto_token,
+ upto_token=upto_token,
since_token=since_token,
potential_recents=events,
newly_joined_room=newly_joined,
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 78a75bfed6..ab8f7610e9 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker:
if row:
threepid = {
- "medium": row["medium"],
- "address": row["address"],
- "validated_at": row["validated_at"],
+ "medium": row.medium,
+ "address": row.address,
+ "validated_at": row.validated_at,
}
# Valid threepid returned, delete from the db
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 636efc33e8..59b914b87e 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -59,7 +59,7 @@ class BasicProxyCredentials(ProxyCredentials):
a Proxy-Authorization header.
"""
# Encode as base64 and prepend the authorization type
- return b"Basic " + base64.encodebytes(self.username_password)
+ return b"Basic " + base64.b64encode(self.username_password)
@attr.s(auto_attribs=True)
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 7fd46901f7..72b0f1c5de 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -949,10 +949,7 @@ class MediaRepository:
deleted = 0
- for media in old_media:
- origin = media["media_origin"]
- media_id = media["media_id"]
- file_id = media["filesystem_id"]
+ for origin, media_id, file_id in old_media:
key = (origin, media_id)
logger.info("Deleting: %r", key)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 0786d20635..755c59274c 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -23,6 +23,7 @@ from typing import (
Generator,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
TypeVar,
@@ -39,6 +40,7 @@ from twisted.web.resource import Resource
from synapse.api import errors
from synapse.api.errors import SynapseError
+from synapse.api.presence import UserPresenceState
from synapse.config import ConfigError
from synapse.events import EventBase
from synapse.events.presence_router import (
@@ -46,6 +48,7 @@ from synapse.events.presence_router import (
GET_USERS_FOR_STATES_CALLBACK,
PresenceRouter,
)
+from synapse.events.utils import ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
@@ -257,6 +260,7 @@ class ModuleApi:
self.custom_template_dir = hs.config.server.custom_template_directory
self._callbacks = hs.get_module_api_callbacks()
self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
+ self._event_serializer = hs.get_event_client_serializer()
try:
app_name = self._hs.config.email.email_app_name
@@ -488,6 +492,25 @@ class ModuleApi:
"""
self._hs.register_module_web_resource(path, resource)
+ def register_add_extra_fields_to_unsigned_client_event_callbacks(
+ self,
+ *,
+ add_field_to_unsigned_callback: Optional[
+ ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
+ ] = None,
+ ) -> None:
+ """Registers a callback that can be used to add fields to the unsigned
+ section of events.
+
+ The callback is called every time an event is sent down to a client.
+
+ Added in Synapse 1.96.0
+ """
+ if add_field_to_unsigned_callback is not None:
+ self._event_serializer.register_add_extra_fields_to_unsigned_client_event_callback(
+ add_field_to_unsigned_callback
+ )
+
#########################################################################
# The following methods can be called by the module at any point in time.
@@ -1184,6 +1207,37 @@ class ModuleApi:
presence_events, [destination]
)
+ async def set_presence_for_users(
+ self, users: Mapping[str, Tuple[str, Optional[str]]]
+ ) -> None:
+ """
+ Update the internal presence state of users.
+
+ This can be used for either local or remote users.
+
+ Note that this method can only be run on the process that is configured to write to the
+ presence stream. By default, this is the main process.
+
+ Added in Synapse v1.96.0.
+ """
+
+ # We pull out the presence handler here to break a cyclic
+ # dependency between the presence router and module API.
+ presence_handler = self._hs.get_presence_handler()
+
+ from synapse.handlers.presence import PresenceHandler
+
+ assert isinstance(presence_handler, PresenceHandler)
+
+ states = await presence_handler.current_state_for_users(users.keys())
+ for user_id, (state, status_msg) in users.items():
+ prev_state = states.setdefault(user_id, UserPresenceState.default(user_id))
+ states[user_id] = prev_state.copy_and_replace(
+ state=state, status_msg=status_msg
+ )
+
+ await presence_handler._update_states(states.values(), force_notify=True)
+
def looping_background_call(
self,
f: Callable,
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 99e7715896..ee0bd84f1e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -21,11 +21,13 @@ from typing import (
Dict,
Iterable,
List,
+ Literal,
Optional,
Set,
Tuple,
TypeVar,
Union,
+ overload,
)
import attr
@@ -44,6 +46,7 @@ from synapse.metrics import LaterGauge
from synapse.streams.config import PaginationConfig
from synapse.types import (
JsonDict,
+ MultiWriterStreamToken,
PersistedEventPosition,
RoomStreamToken,
StrCollection,
@@ -127,7 +130,7 @@ class _NotifierUserStream:
def notify(
self,
stream_key: StreamKeyType,
- stream_id: Union[int, RoomStreamToken],
+ stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken],
time_now_ms: int,
) -> None:
"""Notify any listeners for this user of a new event from an
@@ -452,10 +455,48 @@ class Notifier:
except Exception:
logger.exception("Error pusher pool of event")
+ @overload
+ def on_new_event(
+ self,
+ stream_key: Literal[StreamKeyType.ROOM],
+ new_token: RoomStreamToken,
+ users: Optional[Collection[Union[str, UserID]]] = None,
+ rooms: Optional[StrCollection] = None,
+ ) -> None:
+ ...
+
+ @overload
+ def on_new_event(
+ self,
+ stream_key: Literal[StreamKeyType.RECEIPT],
+ new_token: MultiWriterStreamToken,
+ users: Optional[Collection[Union[str, UserID]]] = None,
+ rooms: Optional[StrCollection] = None,
+ ) -> None:
+ ...
+
+ @overload
+ def on_new_event(
+ self,
+ stream_key: Literal[
+ StreamKeyType.ACCOUNT_DATA,
+ StreamKeyType.DEVICE_LIST,
+ StreamKeyType.PRESENCE,
+ StreamKeyType.PUSH_RULES,
+ StreamKeyType.TO_DEVICE,
+ StreamKeyType.TYPING,
+ StreamKeyType.UN_PARTIAL_STATED_ROOMS,
+ ],
+ new_token: int,
+ users: Optional[Collection[Union[str, UserID]]] = None,
+ rooms: Optional[StrCollection] = None,
+ ) -> None:
+ ...
+
def on_new_event(
self,
stream_key: StreamKeyType,
- new_token: Union[int, RoomStreamToken],
+ new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
) -> None:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 63cf24a14d..38701aea72 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -238,7 +238,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
data[_STREAM_POSITION_KEY] = {
"streams": {
- stream.NAME: stream.current_token(local_instance_name)
+ stream.NAME: stream.minimal_local_current_token()
for stream in streams
},
"instance_name": local_instance_name,
@@ -433,7 +433,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if self.WAIT_FOR_STREAMS:
response[_STREAM_POSITION_KEY] = {
- stream.NAME: stream.current_token(self._instance_name)
+ stream.NAME: stream.minimal_local_current_token()
for stream in self._streams
}
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index d5337fe588..1312b6f21e 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -126,8 +126,9 @@ class ReplicationDataHandler:
StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
)
elif stream_name == ReceiptsStream.NAME:
+ new_token = self.store.get_max_receipt_stream_id()
self.notifier.on_new_event(
- StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
+ StreamKeyType.RECEIPT, new_token, rooms=[row.room_id for row in rows]
)
await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
elif stream_name == ToDeviceStream.NAME:
@@ -279,14 +280,6 @@ class ReplicationDataHandler:
# may be streaming.
self.notifier.notify_replication()
- def on_remote_server_up(self, server: str) -> None:
- """Called when get a new REMOTE_SERVER_UP command."""
-
- # Let's wake up the transaction queue for the server in case we have
- # pending stuff to send to it.
- if self.send_handler:
- self.send_handler.wake_destination(server)
-
async def wait_for_stream_position(
self,
instance_name: str,
@@ -405,9 +398,6 @@ class FederationSenderHandler:
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
- def wake_destination(self, server: str) -> None:
- self.federation_sender.wake_destination(server)
-
async def process_replication_rows(
self, stream_name: str, token: int, rows: list
) -> None:
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b668bb5da1..afd03137f0 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -611,10 +611,14 @@ class ReplicationCommandHandler:
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
- # If the position token matches our current token then we're up to
- # date and there's nothing to do. Otherwise, fetch all updates
- # between then and now.
- missing_updates = cmd.prev_token != current_token
+ # If the incoming previous position is less than our current position
+ # then we're up to date and there's nothing to do. Otherwise, fetch
+ # all updates between then and now.
+ #
+ # Note: We also have to check that `current_token` is at most the
+ # new position, to handle the case where the stream gets "reset"
+ # (e.g. for `caches` and `typing` after the writer's restart).
+ missing_updates = not (cmd.prev_token <= current_token <= cmd.new_token)
while missing_updates:
# Note: There may very well not be any new updates, but we check to
# make sure. This can particularly happen for the event stream where
@@ -644,7 +648,7 @@ class ReplicationCommandHandler:
[stream.parse_row(row) for row in rows],
)
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
+ logger.info("Caught up with stream '%s' to %i", stream_name, current_token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
@@ -657,8 +661,6 @@ class ReplicationCommandHandler:
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
) -> None:
"""Called when get a new REMOTE_SERVER_UP command."""
- self._replication_data_handler.on_remote_server_up(cmd.data)
-
self._notifier.notify_remote_server_up(cmd.data)
def on_LOCK_RELEASED(
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 1d9a29d22e..38abb5df54 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -27,7 +27,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import PositionCommand
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import EventsStream
-from synapse.replication.tcp.streams._base import StreamRow, Token
+from synapse.replication.tcp.streams._base import CachesStream, StreamRow, Token
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -204,6 +204,23 @@ class ReplicationStreamer:
# The token has advanced but there is no data to
# send, so we send a `POSITION` to inform other
# workers of the updated position.
+ #
+ # There are two reasons for this: 1) this instance
+ # requested a stream ID but didn't use it, or 2)
+ # this instance advanced its own stream position due
+ # to receiving notifications about other instances
+ # advancing their stream position.
+
+ # We skip sending `POSITION` for the `caches` stream
+ # for the second case as a) it generates a lot of
+ # traffic as every worker would echo each write, and
+ # b) nothing cares if a given worker's caches stream
+ # position lags.
+ if stream.NAME == CachesStream.NAME:
+ # If there haven't been any writes since the
+ # `last_token` then we're in the second case.
+ if stream.minimal_local_current_token() <= last_token:
+ continue
# Note: `last_token` may not *actually* be the
# last token we sent out in a RDATA or POSITION.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c6088a0f99..58a44029aa 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.util.id_generators import AbstractStreamIdGenerator
logger = logging.getLogger(__name__)
@@ -107,22 +108,10 @@ class Stream:
def __init__(
self,
local_instance_name: str,
- current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
- `current_token_function` and `update_function` are callbacks which
- should be implemented by subclasses.
-
- `current_token_function` takes an instance name, which is a writer to
- the stream, and returns the position in the stream of the writer (as
- viewed from the current process). On the writer process this is where
- the writer has successfully written up to, whereas on other processes
- this is the position which we have received updates up to over
- replication. (Note that most streams have a single writer and so their
- implementations ignore the instance name passed in).
-
`update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more
info.
@@ -133,12 +122,28 @@ class Stream:
update_function: callback go get stream updates, as above
"""
self.local_instance_name = local_instance_name
- self.current_token = current_token_function
self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token(self.local_instance_name)
+ def current_token(self, instance_name: str) -> Token:
+ """This takes an instance name, which is a writer to
+ the stream, and returns the position in the stream of the writer (as
+ viewed from the current process).
+ """
+ # We can't make this an abstract class as it makes mypy unhappy.
+ raise NotImplementedError()
+
+ def minimal_local_current_token(self) -> Token:
+ """Tries to return a minimal current token for the local instance,
+ i.e. for writers this would be the last successful write.
+
+ If local instance is not a writer (or has written yet) then falls back
+ to returning the normal "current token".
+ """
+ raise NotImplementedError()
+
def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
@@ -156,6 +161,14 @@ class Stream:
and `limited` is whether there are more updates to fetch.
"""
current_token = self.current_token(self.local_instance_name)
+
+ # If the minimum current token for the local instance is less than or
+ # equal to the last thing we published, we know that there are no
+ # updates.
+ if self.last_token >= self.minimal_local_current_token():
+ self.last_token = current_token
+ return [], current_token, False
+
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
@@ -190,6 +203,25 @@ class Stream:
return updates, upto_token, limited
+class _StreamFromIdGen(Stream):
+ """Helper class for simple streams that use a stream ID generator"""
+
+ def __init__(
+ self,
+ local_instance_name: str,
+ update_function: UpdateFunction,
+ stream_id_gen: "AbstractStreamIdGenerator",
+ ):
+ self._stream_id_gen = stream_id_gen
+ super().__init__(local_instance_name, update_function)
+
+ def current_token(self, instance_name: str) -> Token:
+ return self._stream_id_gen.get_current_token_for_writer(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self._stream_id_gen.get_minimal_local_current_token()
+
+
def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
@@ -242,17 +274,21 @@ class BackfillStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- self._current_token,
self.store.get_all_new_backfill_event_rows,
)
- def _current_token(self, instance_name: str) -> int:
+ def current_token(self, instance_name: str) -> Token:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
+ def minimal_local_current_token(self) -> Token:
+ # The backfill stream over replication operates on *positive* numbers,
+ # which means we need to negate it.
+ return -self.store._backfill_id_gen.get_minimal_local_current_token()
+
-class PresenceStream(Stream):
+class PresenceStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStreamRow:
user_id: str
@@ -283,9 +319,7 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_current_presence_token),
- update_function,
+ hs.get_instance_name(), update_function, store._presence_id_gen
)
@@ -305,13 +339,18 @@ class PresenceFederationStream(Stream):
ROW_TYPE = PresenceFederationStreamRow
def __init__(self, hs: "HomeServer"):
- federation_queue = hs.get_presence_handler().get_federation_queue()
+ self._federation_queue = hs.get_presence_handler().get_federation_queue()
super().__init__(
hs.get_instance_name(),
- federation_queue.get_current_token,
- federation_queue.get_replication_rows,
+ self._federation_queue.get_replication_rows,
)
+ def current_token(self, instance_name: str) -> Token:
+ return self._federation_queue.get_current_token(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self._federation_queue.get_current_token(self.local_instance_name)
+
class TypingStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -341,20 +380,25 @@ class TypingStream(Stream):
update_function: Callable[
[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
] = typing_writer_handler.get_all_typing_updates
- current_token_function = typing_writer_handler.get_current_token
+ self.current_token_function = typing_writer_handler.get_current_token
else:
# Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
- current_token_function = hs.get_typing_handler().get_current_token
+ self.current_token_function = hs.get_typing_handler().get_current_token
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(current_token_function),
update_function,
)
+ def current_token(self, instance_name: str) -> Token:
+ return self.current_token_function()
-class ReceiptsStream(Stream):
+ def minimal_local_current_token(self) -> Token:
+ return self.current_token_function()
+
+
+class ReceiptsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ReceiptsStreamRow:
room_id: str
@@ -371,12 +415,12 @@ class ReceiptsStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_max_receipt_stream_id),
store.get_all_updated_receipts,
+ store._receipts_id_gen,
)
-class PushRulesStream(Stream):
+class PushRulesStream(_StreamFromIdGen):
"""A user has changed their push rules"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -387,20 +431,16 @@ class PushRulesStream(Stream):
ROW_TYPE = PushRulesStreamRow
def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastores().main
+ store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- self._current_token,
- self.store.get_all_push_rule_updates,
+ store.get_all_push_rule_updates,
+ store._push_rules_stream_id_gen,
)
- def _current_token(self, instance_name: str) -> int:
- push_rules_token = self.store.get_max_push_rules_stream_id()
- return push_rules_token
-
-class PushersStream(Stream):
+class PushersStream(_StreamFromIdGen):
"""A user has added/changed/removed a pusher"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -418,8 +458,8 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_pushers_stream_token),
store.get_all_updated_pushers_rows,
+ store._pushers_id_gen,
)
@@ -447,15 +487,22 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
+ self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- store.get_cache_stream_token_for_writer,
- store.get_all_updated_caches,
+ self.store.get_all_updated_caches,
)
+ def current_token(self, instance_name: str) -> Token:
+ return self.store.get_cache_stream_token_for_writer(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ if self.store._cache_id_gen:
+ return self.store._cache_id_gen.get_minimal_local_current_token()
+ return self.current_token(self.local_instance_name)
+
-class DeviceListsStream(Stream):
+class DeviceListsStream(_StreamFromIdGen):
"""Either a user has updated their devices or a remote server needs to be
told about a device update.
"""
@@ -473,8 +520,8 @@ class DeviceListsStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(self.store.get_device_stream_token),
self._update_function,
+ self.store._device_list_id_gen,
)
async def _update_function(
@@ -525,7 +572,7 @@ class DeviceListsStream(Stream):
return updates, upper_limit_token, devices_limited or signatures_limited
-class ToDeviceStream(Stream):
+class ToDeviceStream(_StreamFromIdGen):
"""New to_device messages for a client"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -539,12 +586,12 @@ class ToDeviceStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_to_device_stream_token),
store.get_all_new_device_messages,
+ store._device_inbox_id_gen,
)
-class AccountDataStream(Stream):
+class AccountDataStream(_StreamFromIdGen):
"""Global or per room account data was changed"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -560,8 +607,8 @@ class AccountDataStream(Stream):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(self.store.get_max_account_data_stream_id),
self._update_function,
+ self.store._account_data_id_gen,
)
async def _update_function(
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ad9b760713..57138fea80 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
+from collections import defaultdict
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr
from synapse.replication.tcp.streams._base import (
- Stream,
StreamRow,
StreamUpdateResult,
Token,
+ _StreamFromIdGen,
)
if TYPE_CHECKING:
@@ -51,8 +52,19 @@ data part are:
* The state_key of the state which has changed
* The event id of the new state
+A "state-all" row is sent whenever the "current state" in a room changes, but there are
+too many state updates for a particular room in the same update. This replaces any
+"state" rows on a per-room basis. The fields in the data part are:
+
+* The room id for the state changes
+
"""
+# Any room with more than _MAX_STATE_UPDATES_PER_ROOM will send a EventsStreamAllStateRow
+# instead of individual EventsStreamEventRow. This is predominantly useful when
+# purging large rooms.
+_MAX_STATE_UPDATES_PER_ROOM = 150
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamRow:
@@ -111,15 +123,23 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id: Optional[str]
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventsStreamAllStateRow(BaseEventsStreamRow):
+ TypeId = "state-all"
+
+ room_id: str
+
+
_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
EventsStreamEventRow,
EventsStreamCurrentStateRow,
+ EventsStreamAllStateRow,
)
TypeToRow = {Row.TypeId: Row for Row in _EventRows}
-class EventsStream(Stream):
+class EventsStream(_StreamFromIdGen):
"""We received a new event, or an event went from being an outlier to not"""
NAME = "events"
@@ -127,9 +147,7 @@ class EventsStream(Stream):
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
super().__init__(
- hs.get_instance_name(),
- self._store._stream_id_gen.get_current_token_for_writer,
- self._update_function,
+ hs.get_instance_name(), self._update_function, self._store._stream_id_gen
)
async def _update_function(
@@ -139,6 +157,12 @@ class EventsStream(Stream):
current_token: Token,
target_row_count: int,
) -> StreamUpdateResult:
+ # The events stream cannot be "reset", so its safe to return early if
+ # the from token is larger than the current token (the DB query will
+ # trivially return 0 rows anyway).
+ if from_token >= current_token:
+ return [], current_token, False
+
# the events stream merges together three separate sources:
# * new events
# * current_state changes
@@ -213,9 +237,28 @@ class EventsStream(Stream):
if stream_id <= upper_limit
)
+ # Separate out rooms that have many state updates, listeners should clear
+ # all state for those rooms.
+ state_updates_by_room = defaultdict(list)
+ for stream_id, room_id, _type, _state_key, _event_id in state_rows:
+ state_updates_by_room[room_id].append(stream_id)
+
+ state_all_rows = [
+ (stream_ids[-1], room_id)
+ for room_id, stream_ids in state_updates_by_room.items()
+ if len(stream_ids) >= _MAX_STATE_UPDATES_PER_ROOM
+ ]
+ state_all_updates: Iterable[Tuple[int, Tuple]] = (
+ (max_stream_id, (EventsStreamAllStateRow.TypeId, (room_id,)))
+ for (max_stream_id, room_id) in state_all_rows
+ )
+
+ # Any remaining state updates are sent individually.
+ state_all_rooms = {room_id for _, room_id in state_all_rows}
state_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
for (stream_id, *rest) in state_rows
+ if rest[0] not in state_all_rooms
)
ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
@@ -224,7 +267,11 @@ class EventsStream(Stream):
)
# we need to return a sorted list, so merge them together.
- updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
+ updates = list(
+ heapq.merge(
+ event_updates, state_all_updates, state_updates, ex_outliers_updates
+ )
+ )
return updates, upper_limit, limited
@classmethod
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 4046bdec69..7f5af5852c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -18,6 +18,7 @@ import attr
from synapse.replication.tcp.streams._base import (
Stream,
+ Token,
current_token_without_instance,
make_http_update_function,
)
@@ -47,7 +48,7 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
- current_token = current_token_without_instance(
+ self.current_token_func = current_token_without_instance(
federation_sender.get_current_token
)
update_function: Callable[
@@ -57,15 +58,21 @@ class FederationStream(Stream):
elif hs.should_send_federation():
# federation sender: Query master process
update_function = make_http_update_function(hs, self.NAME)
- current_token = self._stub_current_token
+ self.current_token_func = self._stub_current_token
else:
# other worker: stub out the update function (we're not interested in
# any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
- current_token = self._stub_current_token
+ self.current_token_func = self._stub_current_token
- super().__init__(hs.get_instance_name(), current_token, update_function)
+ super().__init__(hs.get_instance_name(), update_function)
+
+ def current_token(self, instance_name: str) -> Token:
+ return self.current_token_func(instance_name)
+
+ def minimal_local_current_token(self) -> Token:
+ return self.current_token(self.local_instance_name)
@staticmethod
def _stub_current_token(instance_name: str) -> int:
diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py
index a8ce5ffd72..ad181d7e93 100644
--- a/synapse/replication/tcp/streams/partial_state.py
+++ b/synapse/replication/tcp/streams/partial_state.py
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
import attr
-from synapse.replication.tcp.streams import Stream
+from synapse.replication.tcp.streams._base import _StreamFromIdGen
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow:
room_id: str
-class UnPartialStatedRoomStream(Stream):
+class UnPartialStatedRoomStream(_StreamFromIdGen):
"""
Stream to notify about rooms becoming un-partial-stated;
that is, when the background sync finishes such that we now have full state for
@@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- store.get_un_partial_stated_rooms_token,
store.get_un_partial_stated_rooms_from_stream,
+ store._un_partial_stated_rooms_stream_id_gen,
)
@@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow:
rejection_status_changed: bool
-class UnPartialStatedEventStream(Stream):
+class UnPartialStatedEventStream(_StreamFromIdGen):
"""
Stream to notify about events becoming un-partial-stated.
"""
@@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream):
store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- store.get_un_partial_stated_events_token,
store.get_un_partial_stated_events_from_stream,
+ store._un_partial_stated_events_stream_id_gen,
)
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 8a617af599..a6ce787da1 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet):
destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
)
- response = {"destinations": destinations, "total": total}
+ response = {
+ "destinations": [
+ {
+ "destination": r[0],
+ "retry_last_ts": r[1],
+ "retry_interval": r[2],
+ "failure_ts": r[3],
+ "last_successful_stream_ordering": r[4],
+ }
+ for r in destinations
+ ],
+ "total": total,
+ }
if (start + limit) < total:
response["next_token"] = str(start + len(destinations))
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 436718c8b2..0659f22a89 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -444,7 +444,7 @@ class RoomStateRestServlet(RestServlet):
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
- room_state = self._event_serializer.serialize_events(events.values(), now)
+ room_state = await self._event_serializer.serialize_events(events.values(), now)
ret = {"state": room_state}
return HTTPStatus.OK, ret
@@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id)
- return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
+ result = [
+ {
+ "event_id": ex[0],
+ "state_group": ex[1],
+ "depth": ex[2],
+ "received_ts": ex[3],
+ }
+ for ex in extremities
+ ]
+
+ return HTTPStatus.OK, {"count": len(extremities), "results": result}
class RoomEventContextServlet(RestServlet):
@@ -779,22 +789,22 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
results = {
- "events_before": self._event_serializer.serialize_events(
+ "events_before": await self._event_serializer.serialize_events(
event_context.events_before,
time_now,
bundle_aggregations=event_context.aggregations,
),
- "event": self._event_serializer.serialize_event(
+ "event": await self._event_serializer.serialize_event(
event_context.event,
time_now,
bundle_aggregations=event_context.aggregations,
),
- "events_after": self._event_serializer.serialize_events(
+ "events_after": await self._event_serializer.serialize_events(
event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
),
- "state": self._event_serializer.serialize_events(
+ "state": await self._event_serializer.serialize_events(
event_context.state, time_now
),
"start": event_context.start,
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 19780e4b4c..75d8a37ccf 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet):
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
)
- ret = {"users": users_media, "total": total}
+ ret = {
+ "users": [
+ {
+ "user_id": r[0],
+ "displayname": r[1],
+ "media_count": r[2],
+ "media_length": r[3],
+ }
+ for r in users_media
+ ],
+ "total": total,
+ }
if (start + limit) < total:
ret["next_token"] = start + len(users_media)
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 3eca4fe21f..5705f812a5 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -93,7 +93,7 @@ class EventRestServlet(RestServlet):
event = await self.event_handler.get_event(requester.user, None, event_id)
if event:
- result = self._event_serializer.serialize_event(
+ result = await self._event_serializer.serialize_event(
event,
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index e7fe1332e7..5688d8593d 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -87,7 +87,7 @@ class NotificationsServlet(RestServlet):
"actions": pa.actions,
"ts": pa.received_ts,
"event": (
- self._event_serializer.serialize_event(
+ await self._event_serializer.serialize_event(
notif_events[pa.event_id],
now,
config=serialize_options,
diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py
index d578faa969..054a391f26 100644
--- a/synapse/rest/client/presence.py
+++ b/synapse/rest/client/presence.py
@@ -42,15 +42,13 @@ class PresenceStatusRestServlet(RestServlet):
self.clock = hs.get_clock()
self.auth = hs.get_auth()
- self._use_presence = hs.config.server.use_presence
-
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
- if not self._use_presence:
+ if not self.hs.config.server.presence_enabled:
return 200, {"presence": "offline"}
if requester.user != user:
@@ -96,7 +94,7 @@ class PresenceStatusRestServlet(RestServlet):
except Exception:
raise SynapseError(400, "Unable to parse state")
- if self._use_presence:
+ if self.hs.config.server.track_presence:
await self.presence_handler.set_state(user, requester.device_id, state)
return 200, {}
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 553938ce9d..96f5726911 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -859,7 +859,7 @@ class RoomEventServlet(RestServlet):
# per MSC2676, /rooms/{roomId}/event/{eventId}, should return the
# *original* event, rather than the edited version
- event_dict = self._event_serializer.serialize_event(
+ event_dict = await self._event_serializer.serialize_event(
event,
self.clock.time_msec(),
bundle_aggregations=aggregations,
@@ -911,25 +911,25 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
serializer_options = SerializeEventConfig(requester=requester)
results = {
- "events_before": self._event_serializer.serialize_events(
+ "events_before": await self._event_serializer.serialize_events(
event_context.events_before,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
- "event": self._event_serializer.serialize_event(
+ "event": await self._event_serializer.serialize_event(
event_context.event,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
- "events_after": self._event_serializer.serialize_events(
+ "events_after": await self._event_serializer.serialize_events(
event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
- "state": self._event_serializer.serialize_events(
+ "state": await self._event_serializer.serialize_events(
event_context.state,
time_now,
config=serializer_options,
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 42bdd3bb10..33fde6c6f8 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -384,7 +384,7 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
- invite = self._event_serializer.serialize_event(
+ invite = await self._event_serializer.serialize_event(
room.invite, time_now, config=serialize_options
)
unsigned = dict(invite.get("unsigned", {}))
@@ -415,7 +415,7 @@ class SyncRestServlet(RestServlet):
"""
knocked = {}
for room in rooms:
- knock = self._event_serializer.serialize_event(
+ knock = await self._event_serializer.serialize_event(
room.knock, time_now, config=serialize_options
)
@@ -506,10 +506,10 @@ class SyncRestServlet(RestServlet):
event.room_id,
)
- serialized_state = self._event_serializer.serialize_events(
+ serialized_state = await self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
)
- serialized_timeline = self._event_serializer.serialize_events(
+ serialized_timeline = await self._event_serializer.serialize_events(
timeline_events,
time_now,
config=serialize_options,
diff --git a/synapse/server.py b/synapse/server.py
index 71ead524d6..5bfb4ba4eb 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -786,7 +786,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_event_client_serializer(self) -> EventClientSerializer:
- return EventClientSerializer()
+ return EventClientSerializer(self)
@cache_in_self
def get_password_policy_handler(self) -> PasswordPolicyHandler:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 81f661160c..a4e7048368 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -35,7 +35,6 @@ from typing import (
Tuple,
Type,
TypeVar,
- Union,
cast,
overload,
)
@@ -421,6 +420,16 @@ class LoggingTransaction:
self._do_execute(self.txn.execute, sql, parameters)
def executemany(self, sql: str, *args: Any) -> None:
+ """Repeatedly execute the same piece of SQL with different parameters.
+
+ See https://peps.python.org/pep-0249/#executemany. Note in particular that
+
+ > Use of this method for an operation which produces one or more result sets
+ > constitutes undefined behavior
+
+ so you can't use this for e.g. a SELECT, an UPDATE ... RETURNING, or a
+ DELETE FROM... RETURNING.
+ """
# TODO: we should add a type for *args here. Looking at Cursor.executemany
# and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
# Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy
@@ -606,13 +615,16 @@ class DatabasePool:
If the background updates have not completed, wait 15 sec and check again.
"""
- updates = await self.simple_select_list(
- "background_updates",
- keyvalues=None,
- retcols=["update_name"],
- desc="check_background_updates",
+ updates = cast(
+ List[Tuple[str]],
+ await self.simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ ),
)
- background_update_names = [x["update_name"] for x in updates]
+ background_update_names = [x[0] for x in updates]
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
if update_name not in background_update_names:
@@ -1044,43 +1056,20 @@ class DatabasePool:
results = [dict(zip(col_headers, row)) for row in cursor]
return results
- @overload
- async def execute(
- self, desc: str, decoder: Literal[None], query: str, *args: Any
- ) -> List[Tuple[Any, ...]]:
- ...
-
- @overload
- async def execute(
- self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
- ) -> R:
- ...
-
- async def execute(
- self,
- desc: str,
- decoder: Optional[Callable[[Cursor], R]],
- query: str,
- *args: Any,
- ) -> Union[List[Tuple[Any, ...]], R]:
+ async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
"""Runs a single query for a result set.
Args:
desc: description of the transaction, for logging and metrics
- decoder - The function which can resolve the cursor results to
- something meaningful.
query - The query string to execute
*args - Query args.
Returns:
The result of decoder(results)
"""
- def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
+ def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
txn.execute(query, args)
- if decoder:
- return decoder(txn)
- else:
- return txn.fetchall()
+ return txn.fetchall()
return await self.runInteraction(desc, interaction)
@@ -1804,9 +1793,9 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]],
retcols: Collection[str],
desc: str = "simple_select_list",
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows, returning the result as a list of tuples.
Args:
table: the table name
@@ -1817,8 +1806,7 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
Returns:
- A list of dictionaries, one per result row, each a mapping between the
- column names from `retcols` and that column's value for the row.
+ A list of tuples, one per result row, each the retcolumn's value for the row.
"""
return await self.runInteraction(
desc,
@@ -1836,9 +1824,9 @@ class DatabasePool:
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows, returning the result as a list of tuples.
Args:
txn: Transaction object
@@ -1849,8 +1837,7 @@ class DatabasePool:
retcols: the names of the columns to return
Returns:
- A list of dictionaries, one per result row, each a mapping between the
- column names from `retcols` and that column's value for the row.
+ A list of tuples, one per result row, each the retcolumn's value for the row.
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1863,7 +1850,7 @@ class DatabasePool:
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)
- return cls.cursor_to_dict(txn)
+ return txn.fetchall()
async def simple_select_many_batch(
self,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 39498d52c6..d7482a1f4e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -94,7 +94,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
hs.get_replication_notifier(),
"room_account_data",
"stream_id",
- extra_tables=[("room_tags_revisions", "stream_id")],
+ extra_tables=[
+ ("account_data", "stream_id"),
+ ("room_tags_revisions", "stream_id"),
+ ],
is_writer=self._instance_name in hs.config.worker.writers.account_data,
)
@@ -283,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_room_txn(
txn: LoggingTransaction,
- ) -> Dict[str, JsonDict]:
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "room_account_data",
- {"user_id": user_id, "room_id": room_id},
- ["account_data_type", "content"],
+ ) -> Dict[str, JsonMapping]:
+ rows = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="room_account_data",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=["account_data_type", "content"],
+ ),
)
return {
- row["account_data_type"]: db_to_json(row["content"]) for row in rows
+ account_data_type: db_to_json(content)
+ for account_data_type, content in rows
}
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 073a99cd84..fa7d1c469a 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A list of ApplicationServices, which may be empty.
"""
- results = await self.db_pool.simple_select_list(
- "application_services_state", {"state": state.value}, ["as_id"]
+ results = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="application_services_state",
+ keyvalues={"state": state.value},
+ retcols=("as_id",),
+ ),
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
services = []
- for res in results:
+ for (as_id,) in results:
for service in as_list:
- if service.id == res["as_id"]:
+ if service.id == as_id:
services.append(service)
return services
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2fbd389c71..4d0470ffd9 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
EventsStream,
+ EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
@@ -264,6 +265,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
(data.state_key,)
)
self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined]
+ elif row.type == EventsStreamAllStateRow.TypeId:
+ assert isinstance(data, EventsStreamAllStateRow)
+ # Similar to the above, but the entire caches are invalidated. This is
+ # unfortunate for the membership caches, but should recover quickly.
+ self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined]
+ self.get_rooms_for_user_with_stream_ordering.invalidate_all() # type: ignore[attr-defined]
+ self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined]
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 58177ecec1..711fdddd4e 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"""
rows = await self.db_pool.execute(
- "_censor_redactions_fetch", None, sql, before_ts, 100
+ "_censor_redactions_fetch", sql, before_ts, 100
)
updates = []
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8be1511859..c006129625 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
if device_id is not None:
keyvalues["device_id"] = device_id
- res = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues=keyvalues,
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ res = cast(
+ List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
+ await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ ),
)
return {
- (d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
- user_id=d["user_id"],
- device_id=d["device_id"],
- ip=d["ip"],
- user_agent=d["user_agent"],
- last_seen=d["last_seen"],
+ (user_id, device_id): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip=ip,
+ user_agent=user_agent,
+ last_seen=last_seen,
)
- for d in res
+ for user_id, ip, user_agent, device_id, last_seen in res
}
async def _get_user_ip_and_agents_from_database(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1faa6f04b2..3e7425d4a6 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -478,18 +478,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": "No changes in cache since last check"})
return 0
- ROW_ID_NAME = self.database_engine.row_id_name
-
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
limit_statement = "" if limit is None else f"LIMIT {limit}"
sql = f"""
- DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
- SELECT {ROW_ID_NAME} FROM device_inbox
- WHERE user_id = ? AND device_id = ? AND stream_id <= ?
- {limit_statement}
+ DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= (
+ SELECT MAX(stream_id) FROM (
+ SELECT stream_id FROM device_inbox
+ WHERE user_id = ? AND device_id = ? AND stream_id <= ?
+ ORDER BY stream_id
+ {limit_statement}
+ ) AS q1
)
"""
- txn.execute(sql, (user_id, device_id, up_to_stream_id))
+ txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id))
return txn.rowcount
count = await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fc23d18eba..49edbb9e06 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
allow_none=True,
)
- async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
+ async def get_devices_by_user(
+ self, user_id: str
+ ) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
@@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id:
Returns:
A mapping from device_id to a dict containing "device_id", "user_id"
- and "display_name" for each device.
+ and "display_name" for each device. Display name may be null.
"""
- devices = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues={"user_id": user_id, "hidden": False},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user",
+ devices = cast(
+ List[Tuple[str, str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id, "hidden": False},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user",
+ ),
)
- return {d["device_id"]: d for d in devices}
+ return {
+ d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
+ for d in devices
+ }
async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.
Args:
@@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
- return await self.db_pool.simple_select_list(
- table="device_auth_providers",
- keyvalues={
- "auth_provider_id": auth_provider_id,
- "auth_provider_session_id": auth_provider_session_id,
- },
- retcols=("user_id", "device_id"),
- desc="get_devices_by_auth_provider_session_id",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_auth_providers",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ retcols=("user_id", "device_id"),
+ desc="get_devices_by_auth_provider_session_id",
+ ),
)
@trace
@@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_cached_devices_for_user(
self, user_id: str
) -> Mapping[str, JsonMapping]:
- devices = await self.db_pool.simple_select_list(
- table="device_lists_remote_cache",
- keyvalues={"user_id": user_id},
- retcols=("device_id", "content"),
- desc="get_cached_devices_for_user",
+ devices = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id},
+ retcols=("device_id", "content"),
+ desc="get_cached_devices_for_user",
+ ),
)
- return {
- device["device_id"]: db_to_json(device["content"]) for device in devices
- }
+ return {device[0]: db_to_json(device[1]) for device in devices}
def get_cached_device_list_changes(
self,
@@ -882,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
rows = await self.db_pool.execute(
"get_all_devices_changed",
- None,
sql,
from_key,
to_key,
@@ -966,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
WHERE from_user_id = ? AND stream_id > ?
"""
rows = await self.db_pool.execute(
- "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ "get_users_whose_signatures_changed", sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
else:
@@ -1080,7 +1091,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- row_tuples = cast(
+ rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
@@ -1090,11 +1101,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
)
-
- return {row[0] for row in row_tuples}
else:
rows = cast(
- List[Dict[str, str]],
+ List[Tuple[str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
@@ -1103,7 +1112,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
),
)
- return {row["user_id"] for row in rows}
+ return {row[0] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index aac4cfb054..ad904a26a6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
@@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = await self.db_pool.simple_select_list(
- table="e2e_room_keys",
- keyvalues=keyvalues,
- retcols=(
- "user_id",
- "room_id",
- "session_id",
- "first_message_index",
- "forwarded_count",
- "is_verified",
- "session_data",
+ rows = cast(
+ List[Tuple[str, str, int, int, int, str]],
+ await self.db_pool.simple_select_list(
+ table="e2e_room_keys",
+ keyvalues=keyvalues,
+ retcols=(
+ "room_id",
+ "session_id",
+ "first_message_index",
+ "forwarded_count",
+ "is_verified",
+ "session_data",
+ ),
+ desc="get_e2e_room_keys",
),
- desc="get_e2e_room_keys",
)
sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
- for row in rows:
- room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
- room_entry["sessions"][row["session_id"]] = {
- "first_message_index": row["first_message_index"],
- "forwarded_count": row["forwarded_count"],
+ for (
+ room_id,
+ session_id,
+ first_message_index,
+ forwarded_count,
+ is_verified,
+ session_data,
+ ) in rows:
+ room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
+ room_entry["sessions"][session_id] = {
+ "first_message_index": first_message_index,
+ "forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean
- "is_verified": bool(row["is_verified"]),
- "session_data": db_to_json(row["session_data"]),
+ "is_verified": bool(is_verified),
+ "session_data": db_to_json(session_data),
}
return sessions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f13d776b0d..4f96ac25c7 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -24,6 +24,7 @@ from typing import (
Mapping,
Optional,
Sequence,
+ Set,
Tuple,
Union,
cast,
@@ -155,7 +156,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
- None,
sql,
now_stream_id,
user_id,
@@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
async def claim_e2e_one_time_keys(
- self, query_list: Iterable[Tuple[str, str, str, int]]
+ self, query_list: Collection[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
@@ -1121,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
- A tuple pf:
+ A tuple (results, missing) of:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
- A copy of the input which has not been fulfilled.
+ A copy of the input which has not been fulfilled. The returned counts
+ may be less than the input counts. In this case, the returned counts
+ are the number of claims that were not fulfilled.
"""
-
- @trace
- def _claim_e2e_one_time_key_simple(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that don't support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- sql = """
- SELECT key_id, key_json FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- """
-
- txn.execute(sql, (user_id, device_id, algorithm, count))
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self.db_pool.simple_delete_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- column="key_id",
- values=[otk_row[0] for otk_row in otk_rows],
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- },
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
- @trace
- def _claim_e2e_one_time_key_returning(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- # We can use RETURNING to do the fetch and DELETE in once step.
- sql = """
- DELETE FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- AND key_id IN (
- SELECT key_id FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- )
- RETURNING key_id, key_json
- """
-
- txn.execute(
- sql,
- (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
- )
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str, int]] = []
- for user_id, device_id, algorithm, count in query_list:
- if self.database_engine.supports_returning:
- # If we support RETURNING clause we can use a single query that
- # allows us to use autocommit mode.
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
- db_autocommit = True
- else:
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
- db_autocommit = False
-
- claim_rows = await self.db_pool.runInteraction(
+ if isinstance(self.database_engine, PostgresEngine):
+ # If we can use execute_values we can use a single batch query
+ # in autocommit mode.
+ unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
+ for user_id, device_id, algorithm, count in query_list:
+ unfulfilled_claim_counts[user_id, device_id, algorithm] = count
+
+ bulk_claims = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
- _claim_e2e_one_time_key,
- user_id,
- device_id,
- algorithm,
- count,
- db_autocommit=db_autocommit,
+ self._claim_e2e_one_time_keys_bulk,
+ query_list,
+ db_autocommit=True,
)
- if claim_rows:
+
+ for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- for claim_row in claim_rows:
- device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+ unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
+
# Did we get enough OTKs?
- count -= len(claim_rows)
- if count:
- missing.append((user_id, device_id, algorithm, count))
+ missing = [
+ (user, device, alg, count)
+ for (user, device, alg), count in unfulfilled_claim_counts.items()
+ if count > 0
+ ]
+ else:
+ for user_id, device_id, algorithm, count in query_list:
+ claim_rows = await self.db_pool.runInteraction(
+ "claim_e2e_one_time_keys",
+ self._claim_e2e_one_time_key_simple,
+ user_id,
+ device_id,
+ algorithm,
+ count,
+ db_autocommit=False,
+ )
+ if claim_rows:
+ device_results = results.setdefault(user_id, {}).setdefault(
+ device_id, {}
+ )
+ for claim_row in claim_rows:
+ device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ # Did we get enough OTKs?
+ count -= len(claim_rows)
+ if count:
+ missing.append((user_id, device_id, algorithm, count))
return results, missing
@@ -1261,6 +1193,65 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
+ if isinstance(self.database_engine, PostgresEngine):
+ return await self.db_pool.runInteraction(
+ "_claim_e2e_fallback_keys_bulk",
+ self._claim_e2e_fallback_keys_bulk_txn,
+ query_list,
+ db_autocommit=True,
+ )
+ # Use an UPDATE FROM... RETURNING combined with a VALUES block to do
+ # everything in one query. Note: this is also supported in SQLite 3.33.0,
+ # (see https://www.sqlite.org/lang_update.html#update_from), but we do not
+ # have an equivalent of psycopg2's execute_values to do this in one query.
+ else:
+ return await self._claim_e2e_fallback_keys_simple(query_list)
+
+ def _claim_e2e_fallback_keys_bulk_txn(
+ self,
+ txn: LoggingTransaction,
+ query_list: Iterable[Tuple[str, str, str, bool]],
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Efficient implementation of claim_e2e_fallback_keys for Postgres.
+
+ Safe to autocommit: this is a single query.
+ """
+ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+
+ sql = """
+ WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
+ VALUES ?
+ )
+ UPDATE e2e_fallback_keys_json k
+ SET used = used OR mark_as_used
+ FROM claims
+ WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
+ RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
+ """
+ claimed_keys = cast(
+ List[Tuple[str, str, str, str, str]],
+ txn.execute_values(sql, query_list),
+ )
+
+ seen_user_device: Set[Tuple[str, str]] = set()
+ for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
+ device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+
+ if (user_id, device_id) in seen_user_device:
+ continue
+ seen_user_device.add((user_id, device_id))
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
+
+ return results
+
+ async def _claim_e2e_fallback_keys_simple(
+ self,
+ query_list: Iterable[Tuple[str, str, str, bool]],
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite."""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
@@ -1303,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return results
+ @trace
+ def _claim_e2e_one_time_key_simple(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ algorithm: str,
+ count: int,
+ ) -> List[Tuple[str, str]]:
+ """Claim OTK for device for DBs that don't support RETURNING.
+
+ Returns:
+ A tuple of key name (algorithm + key ID) and key JSON, if an
+ OTK was found.
+ """
+
+ sql = """
+ SELECT key_id, key_json FROM e2e_one_time_keys_json
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (user_id, device_id, algorithm, count))
+ otk_rows = list(txn)
+ if not otk_rows:
+ return []
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ values=[otk_row[0] for otk_row in otk_rows],
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
+
+ @trace
+ def _claim_e2e_one_time_keys_bulk(
+ self,
+ txn: LoggingTransaction,
+ query_list: Iterable[Tuple[str, str, str, int]],
+ ) -> List[Tuple[str, str, str, str, str]]:
+ """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
+
+ Args:
+ query_list: Collection of tuples (user_id, device_id, algorithm, count)
+ as passed to claim_e2e_one_time_keys.
+
+ Returns:
+ A list of tuples (user_id, device_id, algorithm, key_id, key_json)
+ for each OTK claimed.
+ """
+ sql = """
+ WITH claims(user_id, device_id, algorithm, claim_count) AS (
+ VALUES ?
+ ), ranked_keys AS (
+ SELECT
+ user_id, device_id, algorithm, key_id, claim_count,
+ ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
+ FROM e2e_one_time_keys_json
+ JOIN claims USING (user_id, device_id, algorithm)
+ )
+ DELETE FROM e2e_one_time_keys_json k
+ WHERE (user_id, device_id, algorithm, key_id) IN (
+ SELECT user_id, device_id, algorithm, key_id
+ FROM ranked_keys
+ WHERE r <= claim_count
+ )
+ RETURNING user_id, device_id, algorithm, key_id, key_json;
+ """
+ otk_rows = cast(
+ List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
+ )
+
+ seen_user_device: Set[Tuple[str, str]] = set()
+ for user_id, device_id, _, _, _ in otk_rows:
+ if (user_id, device_id) in seen_user_device:
+ continue
+ seen_user_device.add((user_id, device_id))
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return otk_rows
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4f80ce75cc..f1b0991503 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# keeping only the forward extremities (i.e. the events not referenced
# by other events in the queue). We do this so that we can always
# backpaginate in all the events we have dropped.
- rows = await self.db_pool.simple_select_list(
- table="federation_inbound_events_staging",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "event_json"),
- desc="prune_staged_events_in_room_fetch",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "event_json"),
+ desc="prune_staged_events_in_room_fetch",
+ ),
)
# Find the set of events referenced by those in the queue, as well as
# collecting all the event IDs in the queue.
referenced_events: Set[str] = set()
seen_events: Set[str] = set()
- for row in rows:
- event_id = row["event_id"]
+ for event_id, event_json in rows:
seen_events.add(event_id)
- event_d = db_to_json(row["event_json"])
+ event_d = db_to_json(event_json)
# We don't bother parsing the dicts into full blown event objects,
# as that is needlessly expensive.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ef6766b5e0..3c1492e3ad 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2267,35 +2267,59 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- # From the events passed in, add all of the prev events as backwards extremities.
- # Ignore any events that are already backwards extrems or outliers.
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- # 1. Don't add an event as a extremity again if we already persisted it
- # as a non-outlier.
- # 2. Don't add an outlier as an extremity if it has no prev_events
- " AND NOT EXISTS ("
- " SELECT 1 FROM events"
- " LEFT JOIN event_edges edge"
- " ON edge.event_id = events.event_id"
- " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = FALSE OR edge.event_id IS NULL)"
- " )"
+
+ room_id = events[0].room_id
+
+ potential_backwards_extremities = {
+ e_id
+ for ev in events
+ for e_id in ev.prev_event_ids()
+ if not ev.internal_metadata.is_outlier()
+ }
+
+ if not potential_backwards_extremities:
+ return
+
+ existing_events_outliers = self.db_pool.simple_select_many_txn(
+ txn,
+ table="events",
+ column="event_id",
+ iterable=potential_backwards_extremities,
+ keyvalues={"outlier": False},
+ retcols=("event_id",),
)
- txn.execute_batch(
- query,
- [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id)
- for ev in events
- for e_id in ev.prev_event_ids()
- if not ev.internal_metadata.is_outlier()
- ],
+ potential_backwards_extremities.difference_update(
+ e for e, in existing_events_outliers
)
+ if potential_backwards_extremities:
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="event_backward_extremities",
+ key_names=("room_id", "event_id"),
+ key_values=[(room_id, ev) for ev in potential_backwards_extremities],
+ value_names=(),
+ value_values=(),
+ )
+
+ # Record the stream orderings where we have new gaps.
+ gap_events = [
+ (room_id, self._instance_name, ev.internal_metadata.stream_ordering)
+ for ev in events
+ if any(
+ e_id in potential_backwards_extremities
+ for e_id in ev.prev_event_ids()
+ )
+ ]
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="timeline_gaps",
+ keys=("room_id", "instance_name", "stream_ordering"),
+ values=gap_events,
+ )
+
# Delete all these events that we've already fetched and now know that their
# prev events are the new backwards extremeties.
query = (
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c5fce1c82b..0061805150 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it.
- # We need to pass execute a dummy function to handle the txn's result otherwise
- # it tries to call fetchall() on it and fails because there's no result to fetch.
- await self.db_pool.execute(
+ await self.db_pool.runInteraction(
"background_analyze_new_stream_ordering_column",
- lambda txn: None,
- "ANALYZE events(stream_ordering2)",
+ lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
)
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index f851bff604..0ba84b1469 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List
+from typing import List, Optional, Tuple, cast
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
@@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
async def get_forward_extremities_for_room(
self, room_id: str
- ) -> List[Dict[str, Any]]:
- """Get list of forward extremities for a room."""
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
+ """
+ Get list of forward extremities for a room.
+
+ Returns:
+ A list of tuples of event_id, state_group, depth, and received_ts.
+ """
def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
@@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
"""
txn.execute(sql, (room_id,))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8af638d60f..5bf864c1fb 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -2096,12 +2096,6 @@ class EventsWorkerStore(SQLBaseStore):
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
sql = """
- DELETE FROM event_txn_id
- WHERE inserted_ts < ?
- """
- txn.execute(sql, (one_day_ago,))
-
- sql = """
DELETE FROM event_txn_id_device_id
WHERE inserted_ts < ?
"""
diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py
index 654f924019..60621edeef 100644
--- a/synapse/storage/databases/main/experimental_features.py
+++ b/synapse/storage/databases/main/experimental_features.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, FrozenSet
+from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
Returns:
the features currently enabled for the user
"""
- enabled = await self.db_pool.simple_select_list(
- "per_user_experimental_features",
- {"user_id": user_id, "enabled": True},
- ["feature"],
+ enabled = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="per_user_experimental_features",
+ keyvalues={"user_id": user_id, "enabled": True},
+ retcols=("feature",),
+ ),
)
- return frozenset(feature["feature"] for feature in enabled)
+ return frozenset(feature[0] for feature in enabled)
async def set_features_for_user(
self,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ea797864b9..ce88772f9e 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
- rows = await self.db_pool.simple_select_list(
- table="server_keys_json",
- keyvalues={"server_name": server_name},
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
+ rows = cast(
+ List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
+ await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
),
- desc="get_server_keys_json_for_remote",
)
if not rows:
@@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore):
# We sort the rows by ts_added_ms so that the most recently added entry
# will stomp over older entries in the dictionary.
- rows.sort(key=lambda r: r["ts_added_ms"])
+ rows.sort(key=lambda r: r[2])
return {
- row["key_id"]: FetchKeyResultForRemote(
+ key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
- key_json=bytes(row["key_json"]),
- valid_until_ts=row["ts_valid_until_ms"],
- added_ts=row["ts_added_ms"],
+ key_json=bytes(key_json),
+ valid_until_ts=ts_valid_until_ms,
+ added_ts=ts_added_ms,
)
- for row in rows
+ for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 2e6b176bd2..aeb3db596c 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
- rows = await self.db_pool.simple_select_list(
- "local_media_repository_thumbnails",
- {"media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "local_media_repository_thumbnails",
+ {"media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_local_media_thumbnails",
),
- desc="get_local_media_thumbnails",
)
return [
ThumbnailInfo(
- width=row["thumbnail_width"],
- height=row["thumbnail_height"],
- method=row["thumbnail_method"],
- type=row["thumbnail_type"],
- length=row["thumbnail_length"],
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]
@@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails(
self, origin: str, media_id: str
) -> List[ThumbnailInfo]:
- rows = await self.db_pool.simple_select_list(
- "remote_media_cache_thumbnails",
- {"media_origin": origin, "media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "remote_media_cache_thumbnails",
+ {"media_origin": origin, "media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_remote_media_thumbnails",
),
- desc="get_remote_media_thumbnails",
)
return [
ThumbnailInfo(
- width=row["thumbnail_width"],
- height=row["thumbnail_height"],
- method=row["thumbnail_method"],
- type=row["thumbnail_type"],
- length=row["thumbnail_length"],
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]
@@ -652,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.
@@ -666,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
+ * The filesystem ID.
+ """
+
+ sql = """
+ SELECT media_origin, media_id, filesystem_id
+ FROM remote_media_cache
+ WHERE last_access_ts < ?
"""
- sql = (
- "SELECT media_origin, media_id, filesystem_id"
- " FROM remote_media_cache"
- " WHERE last_access_ts < ?"
- )
if include_quarantined_media is False:
# Only include media that has not been quarantined
@@ -679,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
AND quarantined_by IS NULL
"""
- return await self.db_pool.execute(
- "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
+ return cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index f5356e7f80..22025eca56 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -179,46 +179,44 @@ class PushRulesWorkerStore(
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
- rows = await self.db_pool.simple_select_list(
- table="push_rules",
- keyvalues={"user_name": user_id},
- retcols=(
- "user_name",
- "rule_id",
- "priority_class",
- "priority",
- "conditions",
- "actions",
+ rows = cast(
+ List[Tuple[str, int, int, str, str]],
+ await self.db_pool.simple_select_list(
+ table="push_rules",
+ keyvalues={"user_name": user_id},
+ retcols=(
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="get_push_rules_for_user",
),
- desc="get_push_rules_for_user",
)
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+ # Sort by highest priority_class, then highest priority.
+ rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
return _load_rules(
- [
- (
- row["rule_id"],
- row["priority_class"],
- row["conditions"],
- row["actions"],
- )
- for row in rows
- ],
+ [(row[0], row[1], row[3], row[4]) for row in rows],
enabled_map,
self.hs.config.experimental,
)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
- results = await self.db_pool.simple_select_list(
- table="push_rules_enable",
- keyvalues={"user_name": user_id},
- retcols=("rule_id", "enabled"),
- desc="get_push_rules_enabled_for_user",
+ results = cast(
+ List[Tuple[str, Optional[Union[int, bool]]]],
+ await self.db_pool.simple_select_list(
+ table="push_rules_enable",
+ keyvalues={"user_name": user_id},
+ retcols=("rule_id", "enabled"),
+ desc="get_push_rules_enabled_for_user",
+ ),
)
- return {r["rule_id"]: bool(r["enabled"]) for r in results}
+ return {r[0]: bool(r[1]) for r in results}
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c7eb7fc478..a6a1671bd6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore):
async def get_throttle_params_by_room(
self, pusher_id: int
) -> Dict[str, ThrottleParams]:
- res = await self.db_pool.simple_select_list(
- "pusher_throttle",
- {"pusher": pusher_id},
- ["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room",
+ res = cast(
+ List[Tuple[str, Optional[int], Optional[int]]],
+ await self.db_pool.simple_select_list(
+ "pusher_throttle",
+ {"pusher": pusher_id},
+ ["room_id", "last_sent_ts", "throttle_ms"],
+ desc="get_throttle_params_by_room",
+ ),
)
params_by_room = {}
- for row in res:
- params_by_room[row["room_id"]] = ThrottleParams(
- row["last_sent_ts"],
- row["throttle_ms"],
+ for room_id, last_sent_ts, throttle_ms in res:
+ params_by_room[room_id] = ThrottleParams(
+ last_sent_ts or 0, throttle_ms or 0
)
return params_by_room
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index b2645ab43c..56e8eb16a8 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -28,6 +28,8 @@ from typing import (
cast,
)
+from immutabledict import immutabledict
+
from synapse.api.constants import EduTypes
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ MultiWriterStreamToken,
+ PersistedPosition,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"receipts_linearized",
entity_column="room_id",
stream_column="stream_id",
- max_value=max_receipts_stream_id,
+ max_value=max_receipts_stream_id.stream,
limit=10000,
)
self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
prefilled_cache=receipts_stream_prefill,
)
- def get_max_receipt_stream_id(self) -> int:
+ def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
"""Get the current max stream ID for receipts stream"""
- return self._receipts_id_gen.get_current_token()
+
+ min_pos = self._receipts_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._receipts_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return MultiWriterStreamToken(
+ stream=min_pos, instance_map=immutabledict(positions)
+ )
+
+ def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+ return self._receipts_id_gen.get_current_token_for_writer(instance_name)
def get_last_unthreaded_receipt_for_user_txn(
self,
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
async def get_linearized_receipts_for_rooms(
- self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Iterable[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> List[JsonMapping]:
"""Get receipts for multiple rooms for sending to clients.
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
room_ids = self._receipts_stream_cache.get_entities_changed(
- room_ids, from_key
+ room_ids, from_key.stream
)
results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [ev for res in results.values() for ev in res]
async def get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""Get receipts for a single room for sending to clients.
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
- if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+ if not self._receipts_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ ):
return []
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cached(tree=True)
async def _get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
- sql = (
- "SELECT receipt_type, user_id, event_id, data"
- " FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id > ? AND stream_id <= ?"
- )
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized
+ WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, from_key, to_key))
- else:
- sql = (
- "SELECT receipt_type, user_id, event_id, data"
- " FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id <= ?"
+ txn.execute(
+ sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
)
+ else:
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
+ room_id = ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, to_key))
+ txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
- return cast(List[Tuple[str, str, str, str]], txn.fetchall())
+ return [
+ (receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=3,
)
async def _get_linearized_receipts_for_rooms(
- self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Collection[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, Sequence[JsonMapping]]:
if not room_ids:
return {}
@@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
@@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [from_key, to_key] + list(args))
+ txn.execute(
+ sql + clause,
+ [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+ )
else:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id <= ? AND
"""
@@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [to_key] + list(args))
+ txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
- return cast(
- List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
- )
+ return [
+ (room_id, receipt_type, user_id, event_id, thread_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
@@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=2,
)
async def get_linearized_receipts_for_all_rooms(
- self, to_key: int, from_key: Optional[int] = None
+ self,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, JsonMapping]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
@@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [from_key, to_key])
+ txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
else:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [to_key])
+ txn.execute(sql, [to_key.get_max_stream_pos()])
- return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
+ return [
+ (room_id, receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
@@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
+ AND instance_name = ?
ORDER BY stream_id ASC
LIMIT ?
"""
- txn.execute(sql, (last_id, current_id, limit))
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
updates = cast(
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
keyvalues=keyvalues,
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"event_id": event_id,
"event_stream_ordering": stream_ordering,
"data": json_encoder.encode(data),
@@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str],
thread_id: Optional[str],
data: dict,
- ) -> Optional[int]:
+ ) -> Optional[PersistedPosition]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- return stream_id
+ return PersistedPosition(self._instance_name, stream_id)
async def _insert_graph_receipt(
self,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9e8643ae4d..e09ab21593 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -151,6 +151,22 @@ class ThreepidResult:
added_at: int
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidValidationSession:
+ address: str
+ """address of the 3pid"""
+ medium: str
+ """medium of the 3pid"""
+ client_secret: str
+ """a secret provided by the client for this validation session"""
+ session_id: str
+ """ID of the validation session"""
+ last_send_attempt: int
+ """a number serving to dedupe send attempts for this session"""
+ validated_at: Optional[int]
+ """timestamp of when this session was validated if so"""
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -855,13 +871,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
Tuples of (auth_provider, external_id)
"""
- res = await self.db_pool.simple_select_list(
- table="user_external_ids",
- keyvalues={"user_id": mxid},
- retcols=("auth_provider", "external_id"),
- desc="get_external_ids_by_user",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ ),
)
- return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
@@ -997,13 +1015,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
- results = await self.db_pool.simple_select_list(
- "user_threepids",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address", "validated_at", "added_at"],
- desc="user_get_threepids",
+ results = cast(
+ List[Tuple[str, str, int, int]],
+ await self.db_pool.simple_select_list(
+ "user_threepids",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address", "validated_at", "added_at"],
+ desc="user_get_threepids",
+ ),
)
- return [ThreepidResult(**r) for r in results]
+ return [
+ ThreepidResult(
+ medium=r[0],
+ address=r[1],
+ validated_at=r[2],
+ added_at=r[3],
+ )
+ for r in results
+ ]
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
@@ -1042,7 +1071,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="add_user_bound_threepid",
)
- async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
+ async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
@@ -1051,15 +1080,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id: The ID of the user to retrieve threepids for
Returns:
- List of dictionaries containing the following keys:
- medium (str): The medium of the threepid (e.g "email")
- address (str): The address of the threepid (e.g "bob@example.com")
- """
- return await self.db_pool.simple_select_list(
- table="user_threepid_id_server",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address"],
- desc="user_get_bound_threepids",
+ List of tuples of two strings:
+ medium: The medium of the threepid (e.g "email")
+ address: The address of the threepid (e.g "bob@example.com")
+ """
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_threepid_id_server",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address"],
+ desc="user_get_bound_threepids",
+ ),
)
async def remove_user_bound_threepid(
@@ -1156,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
@@ -1171,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
perform no filtering
Returns:
- A dict containing the following:
- * address - address of the 3pid
- * medium - medium of the 3pid
- * client_secret - a secret provided by the client for this validation session
- * session_id - ID of the validation session
- * send_attempt - a number serving to dedupe send attempts for this session
- * validated_at - timestamp of when this session was validated if so
-
- Otherwise None if a validation session is not found
+ A ThreepidValidationSession or None if a validation session is not found
"""
if not client_secret:
raise SynapseError(
@@ -1198,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def get_threepid_validation_session_txn(
txn: LoggingTransaction,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@@ -1213,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
- rows = self.db_pool.cursor_to_dict(txn)
- if not rows:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ return ThreepidValidationSession(
+ address=row[0],
+ session_id=row[1],
+ medium=row[2],
+ client_secret=row[3],
+ last_send_attempt=row[4],
+ validated_at=row[5],
+ )
return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7f40e2c446..419b2c7a22 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -47,7 +47,7 @@ from synapse.storage.databases.main.stream import (
generate_pagination_where_clause,
)
from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, StreamKeyType, StreamToken
+from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -314,7 +314,7 @@ class RelationsWorkerStore(SQLBaseStore):
room_key=next_key,
presence_key=0,
typing_key=0,
- receipt_key=0,
+ receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
@@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_txn(
txn: LoggingTransaction,
) -> List[str]:
- rows = self.db_pool.simple_select_list_txn(
- txn=txn,
- table="event_relations",
- keyvalues={"relates_to_id": event_id},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_list_txn(
+ txn=txn,
+ table="event_relations",
+ keyvalues={"relates_to_id": event_id},
+ retcols=["event_id"],
+ ),
)
- return [row["event_id"] for row in rows]
+ return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 9d24d2c347..3e8fcf1975 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""
room_servers: Dict[str, PartialStateResyncInfo] = {}
- rows = await self.db_pool.simple_select_list(
- table="partial_state_rooms",
- keyvalues={},
- retcols=("room_id", "joined_via"),
- desc="get_server_which_served_partial_join",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="partial_state_rooms",
+ keyvalues={},
+ retcols=("room_id", "joined_via"),
+ desc="get_server_which_served_partial_join",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- joined_via = row["joined_via"]
+ for room_id, joined_via in rows:
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
- rows = await self.db_pool.simple_select_list(
- "partial_state_rooms_servers",
- keyvalues=None,
- retcols=("room_id", "server_name"),
- desc="get_partial_state_rooms",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ "partial_state_rooms_servers",
+ keyvalues=None,
+ retcols=("room_id", "server_name"),
+ desc="get_partial_state_rooms",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- server_name = row["server_name"]
+ for room_id, server_name in rows:
entry = room_servers.get(room_id)
if entry is None:
# There is a foreign key constraint which enforces that every room_id in
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3a87eba430..1ed7f2d0ef 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -482,6 +482,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
desc="get_local_users_in_room",
)
+ async def get_local_users_related_to_room(
+ self, room_id: str
+ ) -> List[Tuple[str, str]]:
+ """
+ Retrieves a list of the current roommembers who are local to the server and their membership status.
+ """
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="local_current_membership",
+ keyvalues={"room_id": room_id},
+ retcols=("user_id", "membership"),
+ desc="get_local_users_in_room",
+ ),
+ )
+
async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
"""
Check whether a given local user is currently joined to the given room.
@@ -940,7 +956,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
like_clause = "%:" + host
rows = await self.db_pool.execute(
- "is_host_joined", None, sql, membership, room_id, like_clause
+ "is_host_joined", sql, membership, room_id, like_clause
)
if not rows:
@@ -1070,13 +1086,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
for fully-joined rooms.
"""
- rows = await self.db_pool.simple_select_list(
- "current_state_events",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "membership"),
- desc="has_completed_background_updates",
+ rows = cast(
+ List[Tuple[str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ "current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "membership"),
+ desc="has_completed_background_updates",
+ ),
)
- return {row["event_id"]: row["membership"] for row in rows}
+ return dict(rows)
# TODO This returns a mutable object, which is generally confusing when using a cache.
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
@@ -1165,7 +1184,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
AND forgotten = 0;
"""
- rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+ rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
# `count(*)` returns always an integer
# If any rows still exist it means someone has not forgotten this room yet
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1d69c4a5f0..dbde9130c6 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -26,6 +26,7 @@ from typing import (
Set,
Tuple,
Union,
+ cast,
)
import attr
@@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = await self.db_pool.execute(
- "search_msgs", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id).
+ results = cast(
+ List[Tuple[Union[int, float], str, str]],
+ await self.db_pool.execute("search_msgs", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
- {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ {"event": event_map[r[2]], "rank": r[0]}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
@@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = search_term
sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
- origin_server_ts, stream_ordering, room_id, event_id
+ room_id, event_id, origin_server_ts, stream_ordering
FROM event_search
WHERE vector @@ websearch_to_tsquery('english', ?) AND
"""
@@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# mypy expects to append only a `str`, not an `int`
args.append(limit)
- results = await self.db_pool.execute(
- "search_rooms", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
+ results = cast(
+ List[Tuple[Union[int, float], str, str, int, int]],
+ await self.db_pool.execute("search_rooms", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
{
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s"
- % (r["origin_server_ts"], r["stream_ordering"]),
+ "event": event_map[r[2]],
+ "rank": r[0],
+ "pagination_token": "%s,%s" % (r[3], r[4]),
}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5b2d0ba870..e96c9b0486 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore):
order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
(size and number). This will return a json list of users and the
total number of users matching the filter criteria.
@@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
search_term: a string to filter user names by
+
Returns:
- A list of user dicts and an integer representing the total number of
- users that exist given this query
+ A tuple of:
+ A list of tuples of user information (the user ID, displayname,
+ total number of media, total length of media) and
+
+ An integer representing the total number of users that exist
+ given this query
"""
def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
filters = []
args: list = []
@@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore):
args += [limit, start]
txn.execute(sql, args)
- users = self.db_pool.cursor_to_dict(txn)
+ users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
return users, count
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index ea06e4eee0..2225f8272d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
row = await self.db_pool.execute(
- "get_current_topological_token", None, sql, room_id, room_id, stream_key
+ "get_current_topological_token", sql, room_id, room_id, stream_key
)
return row[0][0] if row else 0
@@ -1616,3 +1616,49 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcol="instance_name",
desc="get_name_from_instance_id",
)
+
+ async def get_timeline_gaps(
+ self,
+ room_id: str,
+ from_token: Optional[RoomStreamToken],
+ to_token: RoomStreamToken,
+ ) -> Optional[RoomStreamToken]:
+ """Check if there is a gap, and return a token that marks the position
+ of the gap in the stream.
+ """
+
+ sql = """
+ SELECT instance_name, stream_ordering
+ FROM timeline_gaps
+ WHERE room_id = ? AND ? < stream_ordering AND stream_ordering <= ?
+ ORDER BY stream_ordering
+ """
+
+ rows = await self.db_pool.execute(
+ "get_timeline_gaps",
+ sql,
+ room_id,
+ from_token.stream if from_token else 0,
+ to_token.get_max_stream_pos(),
+ )
+
+ if not rows:
+ return None
+
+ positions = [
+ PersistedEventPosition(instance_name, stream_ordering)
+ for instance_name, stream_ordering in rows
+ ]
+ if from_token:
+ positions = [p for p in positions if p.persisted_after(from_token)]
+
+ positions = [p for p in positions if not p.persisted_after(to_token)]
+
+ if positions:
+ # We return a stream token that ensures the event *at* the position
+ # of the gap is included (as the gap is *before* the persisted
+ # event).
+ last_position = positions[-1]
+ return RoomStreamToken(stream=last_position.stream - 1)
+
+ return None
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 61403a98cf..7deda7790e 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag content.
"""
- rows = await self.db_pool.simple_select_list(
- "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_list(
+ "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ ),
)
tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
- for row in rows:
- room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = db_to_json(row["content"])
+ for room_id, tag, content in rows:
+ room_tags = tags_by_room.setdefault(room_id, {})
+ room_tags[tag] = db_to_json(content)
return tags_by_room
async def get_all_updated_tags(
@@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A mapping of tags to tag content.
"""
- rows = await self.db_pool.simple_select_list(
- table="room_tags",
- keyvalues={"user_id": user_id, "room_id": room_id},
- retcols=("tag", "content"),
- desc="get_tags_for_room",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="room_tags",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=("tag", "content"),
+ desc="get_tags_for_room",
+ ),
)
- return {row["tag"]: db_to_json(row["content"]) for row in rows}
+ return {tag: db_to_json(content) for tag, content in rows}
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index c4a6475060..fecddb4144 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
+ int,
+ ]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
total number of destinations matching the filter criteria.
@@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
Returns:
- A tuple of a list of mappings from destination to information
+ A tuple of a list of tuples of destination information:
+ * destination
+ * retry_last_ts
+ * retry_interval
+ * failure_ts
+ * last_successful_stream_ordering
and a count of total destinations.
"""
def get_destinations_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[
+ Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
+ ],
+ int,
+ ]:
order_by_column = DestinationSortOrder(order_by).value
if direction == Direction.BACKWARDS:
@@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
LIMIT ? OFFSET ?
"""
txn.execute(sql, args + [limit, start])
- destinations = self.db_pool.cursor_to_dict(txn)
+ destinations = cast(
+ List[
+ Tuple[
+ str, Optional[int], Optional[int], Optional[int], Optional[int]
+ ]
+ ],
+ txn.fetchall(),
+ )
return destinations, count
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 919c66f553..8ab7c42c4a 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type.
"""
results = {}
- for row in await self.db_pool.simple_select_list(
- table="ui_auth_sessions_credentials",
- keyvalues={"session_id": session_id},
- retcols=("stage_type", "result"),
- desc="get_completed_ui_auth_stages",
- ):
- results[row["stage_type"]] = db_to_json(row["result"])
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id},
+ retcols=("stage_type", "result"),
+ desc="get_completed_ui_auth_stages",
+ ),
+ )
+ for stage_type, result in rows:
+ results[stage_type] = db_to_json(result)
return results
@@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
Returns:
List of user_agent/ip pairs
"""
- rows = await self.db_pool.simple_select_list(
- table="ui_auth_sessions_ips",
- keyvalues={"session_id": session_id},
- retcols=("user_agent", "ip"),
- desc="get_user_agents_ips_to_ui_auth_session",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ ),
)
- return [(row["user_agent"], row["ip"]) for row in rows]
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 23eb92c514..a9f5d68b63 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
raise Exception("Unrecognized database engine")
results = cast(
- List[UserProfile],
- await self.db_pool.execute(
- "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
- ),
+ List[Tuple[str, Optional[str], Optional[str]]],
+ await self.db_pool.execute("search_user_dir", sql, *args),
)
limited = len(results) > limit
- return {"limited": limited, "results": results[0:limit]}
+ return {
+ "limited": limited,
+ "results": [
+ {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
+ for r in results[0:limit]
+ ],
+ }
def _filter_text_for_index(text: str) -> str:
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 6ff533a129..0f9c550b27 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
if max_group is None:
rows = await self.db_pool.execute(
"_background_deduplicate_state",
- None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 09d2a8c5b3..182e429174 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self.db_pool.simple_select_list_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- retcols=("type", "state_key", "event_id"),
+ delta_ids = cast(
+ List[Tuple[str, str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
+ ),
)
return _GetStateGroupDelta(
prev_group,
- {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ {
+ (event_type, state_key): event_id
+ for event_type, state_key, event_id in delta_ids
+ },
)
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 5b50bd66bc..158b528dce 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 82 # remember to update the list below when updating
+SCHEMA_VERSION = 83 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -121,6 +121,9 @@ Changes in SCHEMA_VERSION = 81
Changes in SCHEMA_VERSION = 82
- The insertion_events, insertion_event_extremities, insertion_event_edges, and
batch_events tables are no longer purged in preparation for their removal.
+
+Changes in SCHEMA_VERSION = 83
+ - The event_txn_id is no longer used.
"""
diff --git a/synapse/storage/schema/main/delta/82/05gaps.sql b/synapse/storage/schema/main/delta/82/05gaps.sql
new file mode 100644
index 0000000000..6813b488ca
--- /dev/null
+++ b/synapse/storage/schema/main/delta/82/05gaps.sql
@@ -0,0 +1,25 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Records when we see a "gap in the timeline", due to missing events over
+-- federation. We record this so that we can tell clients there is a gap (by
+-- marking the timeline section of a sync request as limited).
+CREATE TABLE IF NOT EXISTS timeline_gaps (
+ room_id TEXT NOT NULL,
+ instance_name TEXT NOT NULL,
+ stream_ordering BIGINT NOT NULL
+);
+
+CREATE INDEX timeline_gaps_room_id ON timeline_gaps(room_id, stream_ordering);
diff --git a/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite
new file mode 100644
index 0000000000..6c7ad0fd37
--- /dev/null
+++ b/synapse/storage/schema/main/delta/83/03_instance_name_receipts.sql.sqlite
@@ -0,0 +1,17 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This already exists on Postgres.
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d2c874b9a8..9c3eafb562 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -134,6 +134,15 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ def get_minimal_local_current_token(self) -> int:
+ """Tries to return a minimal current token for the local instance,
+ i.e. for writers this would be the last successful write.
+
+ If local instance is not a writer (or has written yet) then falls back
+ to returning the normal "current token".
+ """
+
+ @abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
@@ -312,6 +321,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def get_current_token_for_writer(self, instance_name: str) -> int:
return self.get_current_token()
+ def get_minimal_local_current_token(self) -> int:
+ return self.get_current_token()
+
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with multiple writers.
@@ -408,6 +420,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# The maximum stream ID that we have seen been allocated across any writer.
self._max_seen_allocated_stream_id = 1
+ # The maximum position of the local instance. This can be higher than
+ # the corresponding position in `current_positions` table when there are
+ # no active writes in progress.
+ self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
@@ -427,6 +444,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._current_positions.values(), default=1
)
+ # For the case where `stream_positions` is not up to date,
+ # `_persisted_upto_position` may be higher.
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, self._persisted_upto_position
+ )
+
+ # Bump our local maximum position now that we've loaded things from the
+ # DB.
+ self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
if not writers:
# If there have been no explicit writers given then any instance can
# write to the stream. In which case, let's pre-seed our own
@@ -545,6 +572,14 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance == self._instance_name:
self._current_positions[instance] = stream_id
+ if self._writers:
+ # If we have explicit writers then make sure that each instance has
+ # a position.
+ for writer in self._writers:
+ self._current_positions.setdefault(
+ writer, self._persisted_upto_position
+ )
+
cur.close()
def _load_next_id_txn(self, txn: Cursor) -> int:
@@ -688,6 +723,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur)
+ self._max_position_of_local_instance = max(
+ curr, new_cur, self._max_position_of_local_instance
+ )
self._add_persisted_position(next_id)
@@ -702,10 +740,26 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# persisted up to position. This stops Synapse from doing a full table
# scan when a new writer announces itself over replication.
with self._lock:
- return self._return_factor * self._current_positions.get(
+ if self._instance_name == instance_name:
+ return self._return_factor * self._max_position_of_local_instance
+
+ pos = self._current_positions.get(
instance_name, self._persisted_upto_position
)
+ # We want to return the maximum "current token" that we can for a
+ # writer, this helps ensure that streams progress as fast as
+ # possible.
+ pos = max(pos, self._persisted_upto_position)
+
+ return self._return_factor * pos
+
+ def get_minimal_local_current_token(self) -> int:
+ with self._lock:
+ return self._return_factor * self._current_positions.get(
+ self._instance_name, self._persisted_upto_position
+ )
+
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
@@ -774,6 +828,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+ # Advance our local max position.
+ self._max_position_of_local_instance = max(
+ self._max_position_of_local_instance, self._persisted_upto_position
+ )
+
+ if not self._unfinished_ids and not self._in_flight_fetches:
+ # If we don't have anything in flight, it's safe to advance to the
+ # max seen stream ID.
+ self._max_position_of_local_instance = max(
+ self._max_seen_allocated_stream_id, self._max_position_of_local_instance
+ )
+
# We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position
# if its exactly one greater.
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 609a0978a9..d0bb83b184 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource
-from synapse.types import StreamKeyType, StreamToken
+from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -111,7 +111,7 @@ class EventSources:
room_key=await self.sources.room.get_current_key_for_room(room_id),
presence_key=0,
typing_key=0,
- receipt_key=0,
+ receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 09a88c86a7..4c5b26ad93 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -695,6 +695,90 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
return "s%d" % (self.stream,)
+@attr.s(frozen=True, slots=True, order=False)
+class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
+ """A basic stream token class for streams that supports multiple writers."""
+
+ @classmethod
+ async def parse(cls, store: "DataStore", string: str) -> "MultiWriterStreamToken":
+ try:
+ if string[0].isdigit():
+ return cls(stream=int(string))
+ if string[0] == "m":
+ parts = string[1:].split("~")
+ stream = int(parts[0])
+
+ instance_map = {}
+ for part in parts[1:]:
+ key, value = part.split(".")
+ instance_id = int(key)
+ pos = int(value)
+
+ instance_name = await store.get_name_from_instance_id(instance_id)
+ instance_map[instance_name] = pos
+
+ return cls(
+ stream=stream,
+ instance_map=immutabledict(instance_map),
+ )
+ except CancelledError:
+ raise
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid stream token %r" % (string,))
+
+ async def to_string(self, store: "DataStore") -> str:
+ if self.instance_map:
+ entries = []
+ for name, pos in self.instance_map.items():
+ if pos <= self.stream:
+ # Ignore instances who are below the minimum stream position
+ # (we might know they've advanced without seeing a recent
+ # write from them).
+ continue
+
+ instance_id = await store.get_id_for_instance(name)
+ entries.append(f"{instance_id}.{pos}")
+
+ encoded_map = "~".join(entries)
+ return f"m{self.stream}~{encoded_map}"
+ else:
+ return str(self.stream)
+
+ @staticmethod
+ def is_stream_position_in_range(
+ low: Optional["AbstractMultiWriterStreamToken"],
+ high: Optional["AbstractMultiWriterStreamToken"],
+ instance_name: Optional[str],
+ pos: int,
+ ) -> bool:
+ """Checks if a given persisted position is between the two given tokens.
+
+ If `instance_name` is None then the row was persisted before multi
+ writer support.
+ """
+
+ if low:
+ if instance_name:
+ low_stream = low.instance_map.get(instance_name, low.stream)
+ else:
+ low_stream = low.stream
+
+ if pos <= low_stream:
+ return False
+
+ if high:
+ if instance_name:
+ high_stream = high.instance_map.get(instance_name, high.stream)
+ else:
+ high_stream = high.stream
+
+ if high_stream < pos:
+ return False
+
+ return True
+
+
class StreamKeyType(Enum):
"""Known stream types.
@@ -776,7 +860,9 @@ class StreamToken:
)
presence_key: int
typing_key: int
- receipt_key: int
+ receipt_key: MultiWriterStreamToken = attr.ib(
+ validator=attr.validators.instance_of(MultiWriterStreamToken)
+ )
account_data_key: int
push_rules_key: int
to_device_key: int
@@ -799,8 +885,31 @@ class StreamToken:
while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
+
+ (
+ room_key,
+ presence_key,
+ typing_key,
+ receipt_key,
+ account_data_key,
+ push_rules_key,
+ to_device_key,
+ device_list_key,
+ groups_key,
+ un_partial_stated_rooms_key,
+ ) = keys
+
return cls(
- await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
+ room_key=await RoomStreamToken.parse(store, room_key),
+ presence_key=int(presence_key),
+ typing_key=int(typing_key),
+ receipt_key=await MultiWriterStreamToken.parse(store, receipt_key),
+ account_data_key=int(account_data_key),
+ push_rules_key=int(push_rules_key),
+ to_device_key=int(to_device_key),
+ device_list_key=int(device_list_key),
+ groups_key=int(groups_key),
+ un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
)
except CancelledError:
raise
@@ -813,7 +922,7 @@ class StreamToken:
await self.room_key.to_string(store),
str(self.presence_key),
str(self.typing_key),
- str(self.receipt_key),
+ await self.receipt_key.to_string(store),
str(self.account_data_key),
str(self.push_rules_key),
str(self.to_device_key),
@@ -841,6 +950,11 @@ class StreamToken:
StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value)
)
return new_token
+ elif key == StreamKeyType.RECEIPT:
+ new_token = self.copy_and_replace(
+ StreamKeyType.RECEIPT, self.receipt_key.copy_and_advance(new_value)
+ )
+ return new_token
new_token = self.copy_and_replace(key, new_value)
new_id = new_token.get_field(key)
@@ -859,6 +973,10 @@ class StreamToken:
...
@overload
+ def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken:
+ ...
+
+ @overload
def get_field(
self,
key: Literal[
@@ -866,7 +984,6 @@ class StreamToken:
StreamKeyType.DEVICE_LIST,
StreamKeyType.PRESENCE,
StreamKeyType.PUSH_RULES,
- StreamKeyType.RECEIPT,
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
@@ -875,15 +992,21 @@ class StreamToken:
...
@overload
- def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+ def get_field(
+ self, key: StreamKeyType
+ ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
...
- def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+ def get_field(
+ self, key: StreamKeyType
+ ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]:
"""Returns the stream ID for the given key."""
return getattr(self, key.value)
-StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(
+ RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
+)
@attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 46771a401b..26b46be5e1 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -13,7 +13,7 @@
# limitations under the License.
import queue
-from typing import BinaryIO, Optional, Union, cast
+from typing import Any, BinaryIO, Optional, Union, cast
from twisted.internet import threads
from twisted.internet.defer import Deferred
@@ -58,7 +58,9 @@ class BackgroundFileConsumer:
self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue()
# Deferred that is resolved when finished writing
- self._finished_deferred: Optional[Deferred[None]] = None
+ #
+ # This is really Deferred[None], but mypy doesn't seem to like that.
+ self._finished_deferred: Optional[Deferred[Any]] = None
# If the _writer thread throws an exception it gets stored here.
self._write_exception: Optional[Exception] = None
@@ -80,9 +82,13 @@ class BackgroundFileConsumer:
self.streaming = streaming
self._finished_deferred = run_in_background(
threads.deferToThreadPool,
- self._reactor,
- self._reactor.getThreadPool(),
- self._writer,
+ # mypy seems to get confused with the chaining of ParamSpec from
+ # run_in_background to deferToThreadPool.
+ #
+ # For Twisted trunk, ignore arg-type; for Twisted release ignore unused-ignore.
+ self._reactor, # type: ignore[arg-type,unused-ignore]
+ self._reactor.getThreadPool(), # type: ignore[arg-type,unused-ignore]
+ self._writer, # type: ignore[arg-type,unused-ignore]
)
if not streaming:
self._producer.resumeProducing()
|