diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2b111847b7..e114ab7ec4 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -217,7 +217,7 @@ class SynapseHomeServer(HomeServer):
)
if name in ["media", "federation", "client"]:
- if self.config.server.enable_media_repo:
+ if self.config.media.can_load_media_repo:
media_repo = self.get_media_repository_resource()
resources.update(
{
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 1645470499..dc0e93ffa1 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -126,7 +126,7 @@ class ContentRepositoryConfig(Config):
# Only enable the media repo if either the media repo is enabled or the
# current worker app is the media repo.
if (
- self.root.server.enable_media_repo is False
+ config.get("enable_media_repo", True) is False
and config.get("worker_app") != "synapse.app.media_repository"
):
self.can_load_media_repo = False
diff --git a/synapse/config/server.py b/synapse/config/server.py
index a2b2305776..8bb97df175 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -395,12 +395,6 @@ class ServerConfig(Config):
self.presence_router_config,
) = load_module(presence_router_config, ("presence", "presence_router"))
- # whether to enable the media repository endpoints. This should be set
- # to false if the media repository is running as a separate endpoint;
- # doing so ensures that we will not run cache cleanup jobs on the
- # master, potentially causing inconsistency.
- self.enable_media_repo = config.get("enable_media_repo", True)
-
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API.
self.require_auth_for_profile_requests = config.get(
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 239967fa73..6f2988e64c 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -18,11 +18,13 @@
#
#
import logging
+from enum import Enum
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple
import attr
from immutabledict import immutabledict
+from typing_extensions import assert_never
from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership
from synapse.events import EventBase
@@ -37,7 +39,9 @@ from synapse.types import (
PersistedEventPosition,
Requester,
RoomStreamToken,
+ SlidingSyncStreamToken,
StateMap,
+ StrCollection,
StreamKeyType,
StreamToken,
UserID,
@@ -343,11 +347,13 @@ class SlidingSyncHandler:
self.relations_handler = hs.get_relations_handler()
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
+ self.connection_store = SlidingSyncConnectionStore()
+
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SlidingSyncConfig,
- from_token: Optional[StreamToken] = None,
+ from_token: Optional[SlidingSyncStreamToken] = None,
timeout_ms: int = 0,
) -> SlidingSyncResult:
"""
@@ -382,7 +388,7 @@ class SlidingSyncHandler:
# this returns false, it means we timed out waiting, and we should
# just return an empty response.
before_wait_ts = self.clock.time_msec()
- if not await self.notifier.wait_for_stream_token(from_token):
+ if not await self.notifier.wait_for_stream_token(from_token.stream_token):
logger.warning(
"Timed out waiting for worker to catch up. Returning empty response"
)
@@ -420,7 +426,7 @@ class SlidingSyncHandler:
sync_config.user.to_string(),
timeout_ms,
current_sync_callback,
- from_token=from_token,
+ from_token=from_token.stream_token,
)
return result
@@ -430,7 +436,7 @@ class SlidingSyncHandler:
self,
sync_config: SlidingSyncConfig,
to_token: StreamToken,
- from_token: Optional[StreamToken] = None,
+ from_token: Optional[SlidingSyncStreamToken] = None,
) -> SlidingSyncResult:
"""
Generates the response body of a Sliding Sync result, represented as a
@@ -451,6 +457,12 @@ class SlidingSyncHandler:
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
+ await self.connection_store.mark_token_seen(
+ user_id,
+ conn_id=sync_config.connection_id(),
+ from_token=from_token,
+ )
+
# Get all of the room IDs that the user should be able to see in the sync
# response
has_lists = sync_config.lists is not None and len(sync_config.lists) > 0
@@ -463,7 +475,7 @@ class SlidingSyncHandler:
await self.get_room_membership_for_user_at_to_token(
user=sync_config.user,
to_token=to_token,
- from_token=from_token,
+ from_token=from_token.stream_token if from_token else None,
)
)
@@ -605,7 +617,7 @@ class SlidingSyncHandler:
@tag_args
async def handle_room(room_id: str) -> None:
room_sync_result = await self.get_room_sync_data(
- user=sync_config.user,
+ sync_config=sync_config,
room_id=room_id,
room_sync_config=room_sync_config,
room_membership_for_user_at_to_token=room_membership_for_user_map[
@@ -624,8 +636,21 @@ class SlidingSyncHandler:
sync_config=sync_config, to_token=to_token
)
+ if has_lists or has_room_subscriptions:
+ connection_token = await self.connection_store.record_rooms(
+ user_id,
+ conn_id=sync_config.connection_id(),
+ from_token=from_token,
+ sent_room_ids=relevant_room_map.keys(),
+ unsent_room_ids=[], # TODO: We currently assume that we have sent down all updates.
+ )
+ elif from_token:
+ connection_token = from_token.connection_token
+ else:
+ connection_token = 0
+
return SlidingSyncResult(
- next_pos=to_token,
+ next_pos=SlidingSyncStreamToken(to_token, connection_token),
lists=lists,
rooms=rooms,
extensions=extensions,
@@ -713,10 +738,17 @@ class SlidingSyncHandler:
instance_to_max_stream_ordering_map[instance_name] = stream_ordering
# Then assemble the `RoomStreamToken`
+ min_stream_pos = min(instance_to_max_stream_ordering_map.values())
membership_snapshot_token = RoomStreamToken(
# Minimum position in the `instance_map`
- stream=min(instance_to_max_stream_ordering_map.values()),
- instance_map=immutabledict(instance_to_max_stream_ordering_map),
+ stream=min_stream_pos,
+ instance_map=immutabledict(
+ {
+ instance_name: stream_pos
+ for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
+ if stream_pos > min_stream_pos
+ }
+ ),
)
# Since we fetched the users room list at some point in time after the from/to
@@ -1359,11 +1391,11 @@ class SlidingSyncHandler:
async def get_room_sync_data(
self,
- user: UserID,
+ sync_config: SlidingSyncConfig,
room_id: str,
room_sync_config: RoomSyncConfig,
room_membership_for_user_at_to_token: _RoomMembershipForUser,
- from_token: Optional[StreamToken],
+ from_token: Optional[SlidingSyncStreamToken],
to_token: StreamToken,
) -> SlidingSyncResult.RoomResult:
"""
@@ -1381,6 +1413,38 @@ class SlidingSyncHandler:
from_token: The point in the stream to sync from.
to_token: The point in the stream to sync up to.
"""
+ user = sync_config.user
+
+ # Determine whether we should limit the timeline to the token range.
+ #
+ # We should return historical messages (before token range) in the
+ # following cases because we want clients to be able to show a basic
+ # screen of information:
+ # - Initial sync (because no `from_token` to limit us anyway)
+ # - When users `newly_joined`
+ # - For an incremental sync where we haven't sent it down this
+ # connection before
+ to_bound = None
+ initial = True
+ if from_token and not room_membership_for_user_at_to_token.newly_joined:
+ room_status = await self.connection_store.have_sent_room(
+ user_id=user.to_string(),
+ conn_id=sync_config.connection_id(),
+ connection_token=from_token.connection_token,
+ room_id=room_id,
+ )
+ if room_status.status == HaveSentRoomFlag.LIVE:
+ to_bound = from_token.stream_token.room_key
+ initial = False
+ elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
+ assert room_status.last_token is not None
+ to_bound = room_status.last_token
+ initial = False
+ elif room_status.status == HaveSentRoomFlag.NEVER:
+ to_bound = None
+ initial = True
+ else:
+ assert_never(room_status.status)
# Assemble the list of timeline events
#
@@ -1417,22 +1481,6 @@ class SlidingSyncHandler:
room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
)
- # Determine whether we should limit the timeline to the token range.
- #
- # We should return historical messages (before token range) in the
- # following cases because we want clients to be able to show a basic
- # screen of information:
- # - Initial sync (because no `from_token` to limit us anyway)
- # - When users `newly_joined`
- # - TODO: For an incremental sync where we haven't sent it down this
- # connection before
- to_bound = (
- from_token.room_key
- if from_token is not None
- and not room_membership_for_user_at_to_token.newly_joined
- else None
- )
-
fiddled_timeline_limit = room_sync_config.timeline_limit
if to_bound:
fiddled_timeline_limit = max(fiddled_timeline_limit, 10)
@@ -1498,7 +1546,9 @@ class SlidingSyncHandler:
instance_name=timeline_event.internal_metadata.instance_name,
stream=timeline_event.internal_metadata.stream_ordering,
)
- if persisted_position.persisted_after(from_token.room_key):
+ if persisted_position.persisted_after(
+ from_token.stream_token.room_key
+ ):
num_live += 1
else:
# Since we're iterating over the timeline events in
@@ -1555,12 +1605,6 @@ class SlidingSyncHandler:
# indicate to the client that a state reset happened. Perhaps we should indicate
# this by setting `initial: True` and empty `required_state`.
- # TODO: Since we can't determine whether we've already sent a room down this
- # Sliding Sync connection before (we plan to add this optimization in the
- # future), we're always returning the requested room state instead of
- # updates.
- initial = True
-
# Check whether the room has a name set
name_state_ids = await self.get_current_state_ids_at(
room_id=room_id,
@@ -1711,9 +1755,17 @@ class SlidingSyncHandler:
to_token=to_token,
)
else:
- # TODO: Once we can figure out if we've sent a room down this connection before,
- # we can return updates instead of the full required state.
- raise NotImplementedError()
+ assert to_bound is not None
+
+ deltas = await self.store.get_current_state_deltas_for_room(
+ room_id, to_bound, to_token.room_key
+ )
+ # TODO: Filter room state before fetching events
+ # TODO: Handle state resets where event_id is None
+ events = await self.store.get_events(
+ [d.event_id for d in deltas if d.event_id]
+ )
+ room_state = {(s.type, s.state_key): s for s in events.values()}
required_room_state: StateMap[EventBase] = {}
if required_state_filter != StateFilter.none():
@@ -1841,7 +1893,7 @@ class SlidingSyncHandler:
"""
user_id = sync_config.user.to_string()
- device_id = sync_config.device_id
+ device_id = sync_config.requester.device_id
# Check that this request has a valid device ID (not all requests have
# to belong to a device, and so device_id is None), and that the
@@ -1898,3 +1950,198 @@ class SlidingSyncHandler:
next_batch=f"{stream_id}",
events=messages,
)
+
+
+class HaveSentRoomFlag(Enum):
+ """Flag for whether we have sent the room down a sliding sync connection.
+
+ The valid state changes here are:
+ NEVER -> LIVE
+ LIVE -> PREVIOUSLY
+ PREVIOUSLY -> LIVE
+ """
+
+ # The room has never been sent down (or we have forgotten that we sent it
+ # down).
+ NEVER = 1
+
+ # We have previously sent the room down, but there are updates that we
+ # haven't sent down yet (see `HaveSentRoom.last_token`).
+ PREVIOUSLY = 2
+
+ # We have sent the room down and the client has received all updates.
+ LIVE = 3
+
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class HaveSentRoom:
+ """Whether we have sent the room down a sliding sync connection.
+
+ Attributes:
+ status: Flag indicating whether we have sent the room down
+ last_token: If the flag is `PREVIOUSLY` then this is non-null and
+ contains the stream token of the last updates we sent down for
+ the room, i.e. we still need to send everything since then to the
+ client. It is `None` in the `NEVER` and `LIVE` module constants.
+ """
+
+ status: HaveSentRoomFlag
+ last_token: Optional[RoomStreamToken]
+
+ @staticmethod
+ def previously(last_token: RoomStreamToken) -> "HaveSentRoom":
+ """Build a `HaveSentRoom` with the `PREVIOUSLY` flag and `last_token`."""
+ return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token)
+
+
+HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None)
+HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None)
+
+
+@attr.s(auto_attribs=True)
+class SlidingSyncConnectionStore:
+ """In-memory store of per-connection state, including what rooms we have
+ previously sent down a sliding sync connection.
+
+ Note: This is NOT safe to run in a worker setup.
+
+ The complication here is that we need to handle requests being resent, i.e.
+ if we sent down a room in a response that the client did not receive, we
+ must consider the room *not* sent when we get the request again.
+
+ This is handled by using an integer "token", which is returned to the client
+ as part of the sync token. For each connection we store a mapping from
+ tokens to the room states, and create a new entry when we send down new
+ rooms.
+
+ Note that for any given sliding sync connection we will only store a maximum
+ of two different tokens: the previous token from the request and a new token
+ sent in the response. When we receive a request with a given token, we then
+ clear out all other entries with a different token.
+
+ Attributes:
+ _connections: Mapping from `(user_id, conn_id)` to mapping of `token`
+ to mapping of room ID to `HaveSentRoom`.
+ """
+
+ # `(user_id, conn_id)` -> `token` -> `room_id` -> `HaveSentRoom`
+ _connections: Dict[Tuple[str, str], Dict[int, Dict[str, HaveSentRoom]]] = (
+ attr.Factory(dict)
+ )
+
+ async def have_sent_room(
+ self, user_id: str, conn_id: str, connection_token: int, room_id: str
+ ) -> HaveSentRoom:
+ """Return whether we have previously sent the room down the connection
+ identified by the given user_id/conn_id/token (`NEVER` if unknown).
+ """
+ # Use `get` rather than `setdefault`: a lookup must not create entries.
+ sync_statuses = self._connections.get((user_id, conn_id), {})
+ room_status = sync_statuses.get(connection_token, {}).get(
+ room_id, HAVE_SENT_ROOM_NEVER
+ )
+
+ return room_status
+
+ async def record_rooms(
+ self,
+ user_id: str,
+ conn_id: str,
+ from_token: Optional[SlidingSyncStreamToken],
+ *,
+ sent_room_ids: StrCollection,
+ unsent_room_ids: StrCollection,
+ ) -> int:
+ """Record which rooms we have/haven't sent down in a new response.
+
+ Args:
+ user_id, conn_id: Together identify the sliding sync connection.
+ from_token: The since token from the request, if any.
+ sent_room_ids: Room IDs that we have sent down as part of this request
+ (only needs to be ones we didn't previously send down).
+ unsent_room_ids: Room IDs that have had updates since `from_token`,
+ but which were not included in this request.
+
+ Returns:
+ The connection token for the response: a new token if any room status changed, otherwise the token from `from_token` (0 if there is none).
+ """
+ prev_connection_token = 0
+ if from_token is not None:
+ prev_connection_token = from_token.connection_token
+
+ # If there are no changes then this is a noop.
+ if not sent_room_ids and not unsent_room_ids:
+ return prev_connection_token
+
+ sync_statuses = self._connections.setdefault((user_id, conn_id), {})
+
+ # Generate a new token, removing any existing entries in that token
+ # (which can happen if requests get resent).
+ new_store_token = prev_connection_token + 1
+ sync_statuses.pop(new_store_token, None)
+
+ # Copy over and update the room mappings.
+ new_room_statuses = dict(sync_statuses.get(prev_connection_token, {}))
+
+ # Whether we have updated the `new_room_statuses`, if we don't by the
+ # end we can treat this as a noop.
+ have_updated = False
+ for room_id in sent_room_ids:
+ new_room_statuses[room_id] = HAVE_SENT_ROOM_LIVE
+ have_updated = True
+
+ # Whether we add/update the entries for unsent rooms depends on the
+ # existing entry:
+ # - LIVE: We have previously sent down everything up to the request's
+ # `from_token`, so we update the entry to be `PREVIOUSLY` with that
+ # token's room key.
+ # - PREVIOUSLY: We have previously sent down everything up to *a*
+ # given token, so we don't need to update the entry.
+ # - NEVER: We have never previously sent down the room, and we haven't
+ # sent anything down this time either so we leave it as NEVER.
+
+ # Work out the new state for unsent rooms that were `LIVE`.
+ if from_token:
+ new_unsent_state = HaveSentRoom.previously(from_token.stream_token.room_key)
+ else:
+ new_unsent_state = HAVE_SENT_ROOM_NEVER
+
+ for room_id in unsent_room_ids:
+ prev_state = new_room_statuses.get(room_id)
+ if prev_state is not None and prev_state.status == HaveSentRoomFlag.LIVE:
+ new_room_statuses[room_id] = new_unsent_state
+ have_updated = True
+
+ if not have_updated:
+ return prev_connection_token
+
+ sync_statuses[new_store_token] = new_room_statuses
+
+ return new_store_token
+
+ async def mark_token_seen(
+ self,
+ user_id: str,
+ conn_id: str,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> None:
+ """We have received a request with the given token, so we can clear out
+ any other tokens associated with the connection.
+
+ If there is no from token then we have started afresh, and so we delete
+ all tokens associated with the connection.
+ """
+ # Clear out any tokens for the connection that don't match the one
+ # from the request.
+
+ sync_statuses = self._connections.pop((user_id, conn_id), {})
+ if from_token is None:
+ return
+
+ sync_statuses = {
+ i: room_statuses
+ for i, room_statuses in sync_statuses.items()
+ if i == from_token.connection_token
+ }
+ if sync_statuses:
+ self._connections[(user_id, conn_id)] = sync_statuses
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 749b01dd0e..6fd75fd381 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -90,7 +90,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import JsonDict
from synapse.util import json_decoder
-from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
+from synapse.util.async_helpers import AwakenableSleeper, Linearizer, timeout_deferred
from synapse.util.metrics import Measure
from synapse.util.stringutils import parse_and_validate_server_name
@@ -475,6 +475,8 @@ class MatrixFederationHttpClient:
use_proxy=True,
)
+ self.remote_download_linearizer = Linearizer("remote_download_linearizer", 6)
+
def wake_destination(self, destination: str) -> None:
"""Called when the remote server may have come back online."""
@@ -1486,35 +1488,44 @@ class MatrixFederationHttpClient:
)
headers = dict(response.headers.getAllRawHeaders())
-
expected_size = response.length
- # if we don't get an expected length then use the max length
+
if expected_size == UNKNOWN_LENGTH:
expected_size = max_size
- logger.debug(
- f"File size unknown, assuming file is max allowable size: {max_size}"
- )
+ else:
+ if int(expected_size) > max_size:
+ msg = "Requested file is too large > %r bytes" % (max_size,)
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
- read_body, _ = await download_ratelimiter.can_do_action(
- requester=None,
- key=ip_address,
- n_actions=expected_size,
- )
- if not read_body:
- msg = "Requested file size exceeds ratelimits"
- logger.warning(
- "{%s} [%s] %s",
- request.txn_id,
- request.destination,
- msg,
+ read_body, _ = await download_ratelimiter.can_do_action(
+ requester=None,
+ key=ip_address,
+ n_actions=expected_size,
)
- raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+ if not read_body:
+ msg = "Requested file size exceeds ratelimits"
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(
+ HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED
+ )
try:
- # add a byte of headroom to max size as function errs at >=
- d = read_body_with_max_size(response, output_stream, expected_size + 1)
- d.addTimeout(self.default_timeout_seconds, self.reactor)
- length = await make_deferred_yieldable(d)
+ async with self.remote_download_linearizer.queue(ip_address):
+ # add a byte of headroom to max size as function errs at >=
+ d = read_body_with_max_size(response, output_stream, expected_size + 1)
+ d.addTimeout(self.default_timeout_seconds, self.reactor)
+ length = await make_deferred_yieldable(d)
except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (expected_size,)
logger.warning(
@@ -1560,6 +1571,13 @@ class MatrixFederationHttpClient:
request.method,
request.uri.decode("ascii"),
)
+
+ # if we didn't know the length upfront, decrement the actual size from ratelimiter
+ if response.length == UNKNOWN_LENGTH:
+ download_ratelimiter.record_action(
+ requester=None, key=ip_address, n_actions=length
+ )
+
return length, headers
async def federation_get_file(
@@ -1630,29 +1648,37 @@ class MatrixFederationHttpClient:
)
headers = dict(response.headers.getAllRawHeaders())
-
expected_size = response.length
- # if we don't get an expected length then use the max length
+
if expected_size == UNKNOWN_LENGTH:
expected_size = max_size
- logger.debug(
- f"File size unknown, assuming file is max allowable size: {max_size}"
- )
+ else:
+ if int(expected_size) > max_size:
+ msg = "Requested file is too large > %r bytes" % (max_size,)
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
- read_body, _ = await download_ratelimiter.can_do_action(
- requester=None,
- key=ip_address,
- n_actions=expected_size,
- )
- if not read_body:
- msg = "Requested file size exceeds ratelimits"
- logger.warning(
- "{%s} [%s] %s",
- request.txn_id,
- request.destination,
- msg,
+ read_body, _ = await download_ratelimiter.can_do_action(
+ requester=None,
+ key=ip_address,
+ n_actions=expected_size,
)
- raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+ if not read_body:
+ msg = "Requested file size exceeds ratelimits"
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(
+ HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED
+ )
# this should be a multipart/mixed response with the boundary string in the header
try:
@@ -1672,11 +1698,12 @@ class MatrixFederationHttpClient:
raise SynapseError(HTTPStatus.BAD_GATEWAY, msg)
try:
- # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >=
- deferred = read_multipart_response(
- response, output_stream, boundary, expected_size + 1
- )
- deferred.addTimeout(self.default_timeout_seconds, self.reactor)
+ async with self.remote_download_linearizer.queue(ip_address):
+ # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >=
+ deferred = read_multipart_response(
+ response, output_stream, boundary, expected_size + 1
+ )
+ deferred.addTimeout(self.default_timeout_seconds, self.reactor)
except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (expected_size,)
logger.warning(
@@ -1743,6 +1770,13 @@ class MatrixFederationHttpClient:
request.method,
request.uri.decode("ascii"),
)
+
+ # if we didn't know the length upfront, decrement the actual size from ratelimiter
+ if response.length == UNKNOWN_LENGTH:
+ download_ratelimiter.record_action(
+ requester=None, key=ip_address, n_actions=length
+ )
+
return length, headers, multipart_response.json
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index b7e2c99455..7c91b15cef 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -54,7 +54,7 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.rest.admin.experimental_features import ExperimentalFeature
-from synapse.types import JsonDict, Requester, StreamToken
+from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
from synapse.types.rest.client import SlidingSyncBody
from synapse.util import json_decoder
from synapse.util.caches.lrucache import LruCache
@@ -881,7 +881,6 @@ class SlidingSyncRestServlet(RestServlet):
)
user = requester.user
- device_id = requester.device_id
timeout = parse_integer(request, "timeout", default=0)
# Position in the stream
@@ -889,7 +888,9 @@ class SlidingSyncRestServlet(RestServlet):
from_token = None
if from_token_string is not None:
- from_token = await StreamToken.from_string(self.store, from_token_string)
+ from_token = await SlidingSyncStreamToken.from_string(
+ self.store, from_token_string
+ )
# TODO: We currently don't know whether we're going to use sticky params or
# maybe some filters like sync v2 where they are built up once and referenced
@@ -904,11 +905,12 @@ class SlidingSyncRestServlet(RestServlet):
sync_config = SlidingSyncConfig(
user=user,
- device_id=device_id,
+ requester=requester,
# FIXME: Currently, we're just manually copying the fields from the
- # `SlidingSyncBody` into the config. How can we gurantee into the future
+ # `SlidingSyncBody` into the config. How can we guarantee into the future
# that we don't forget any? I would like something more structured like
# `copy_attributes(from=body, to=config)`
+ conn_id=body.conn_id,
lists=body.lists,
room_subscriptions=body.room_subscriptions,
extensions=body.extensions,
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 036972ac25..cd6cb2c7a9 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -26,6 +26,8 @@ import attr
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
+from synapse.storage.databases.main.stream import _filter_results_by_stream
+from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -156,3 +158,39 @@ class StateDeltasStore(SQLBaseStore):
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
+
+ async def get_current_state_deltas_for_room(
+ self, room_id: str, from_token: RoomStreamToken, to_token: RoomStreamToken
+ ) -> List[StateDelta]:
+ """Get the state deltas for the given room that have happened between
+ the two tokens, ordered by ascending stream ID."""
+
+ def get_current_state_deltas_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> List[StateDelta]:
+ sql = """
+ SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id
+ FROM current_state_delta_stream
+ WHERE room_id = ? AND ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ txn.execute(
+ sql, (room_id, from_token.stream, to_token.get_max_stream_pos())
+ )
+
+ return [
+ StateDelta(
+ stream_id=row[1],
+ room_id=room_id,
+ event_type=row[2],
+ state_key=row[3],
+ event_id=row[4],
+ prev_event_id=row[5],
+ )
+ for row in txn
+ if _filter_results_by_stream(from_token, to_token, row[0], row[1])
+ ]
+
+ return await self.db_pool.runInteraction(
+ "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn
+ )
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index b22a13ef01..23ac1842f8 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -20,6 +20,7 @@
#
#
import abc
+import logging
import re
import string
from enum import Enum
@@ -74,6 +75,9 @@ if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
+
+logger = logging.getLogger(__name__)
+
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
@@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
represented by a default `stream` attribute and a map of instance name to
stream position of any writers that are ahead of the default stream
position.
+
+ The values in `instance_map` must be greater than the `stream` attribute.
"""
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
kw_only=True,
)
+ def __attrs_post_init__(self) -> None:
+ # Enforce that every entry in `instance_map` is strictly greater than
+ # the main `stream` position; raise `ValueError` otherwise.
+ for i, v in self.instance_map.items():
+ if v <= self.stream:
+ raise ValueError(
+ f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
+ )
+
@classmethod
@abc.abstractmethod
async def parse(cls, store: "DataStore", string: str) -> "Self":
@@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
for instance in set(self.instance_map).union(other.instance_map)
}
+ # Filter out any redundant entries.
+ instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
+
return attr.evolve(
self, stream=max_stream, instance_map=immutabledict(instance_map)
)
@@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
def bound_stream_token(self, max_stream: int) -> "Self":
"""Bound the stream positions to a maximum value"""
+ min_pos = min(self.stream, max_stream)
return type(self)(
- stream=min(self.stream, max_stream),
+ stream=min_pos,
instance_map=immutabledict(
- {k: min(s, max_stream) for k, s in self.instance_map.items()}
+ {
+ k: min(s, max_stream)
+ for k, s in self.instance_map.items()
+ if min(s, max_stream) > min_pos
+ }
),
)
@@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
)
+ super().__attrs_post_init__()
+
@classmethod
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
@@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
+ if not part:
+ # Handle tokens of the form `m5~`, which were created by
+ # a bug
+ continue
+
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
- pass
+ # We log an exception here as even though this *might* be a client
+ # handing a bad token, it's more likely that Synapse returned a bad
+ # token (and we really want to catch those!).
+ logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid room stream token %r" % (string,))
@classmethod
@@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
return self.instance_map.get(instance_name, self.stream)
async def to_string(self, store: "DataStore") -> str:
+ """See class level docstring for information about the format."""
+
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
elif self.instance_map:
@@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
- encoded_map = "~".join(entries)
- return f"m{self.stream}~{encoded_map}"
+ if entries:
+ encoded_map = "~".join(entries)
+ return f"m{self.stream}~{encoded_map}"
+ return f"s{self.stream}"
else:
return "s%d" % (self.stream,)
@@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
+ if not part:
+ # Handle tokens of the form `m5~`, which were created by
+ # a bug
+ continue
+
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
- pass
+ # We log an exception here as even though this *might* be a client
+ # handing a bad token, it's more likely that Synapse returned a bad
+ # token (and we really want to catch those!).
+ logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid stream token %r" % (string,))
async def to_string(self, store: "DataStore") -> str:
+ """See class level docstring for information about the format."""
+
if self.instance_map:
entries = []
for name, pos in self.instance_map.items():
@@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
- encoded_map = "~".join(entries)
- return f"m{self.stream}~{encoded_map}"
+ if entries:
+ encoded_map = "~".join(entries)
+ return f"m{self.stream}~{encoded_map}"
+ return str(self.stream)
else:
return str(self.stream)
@@ -1089,6 +1138,43 @@ StreamToken.START = StreamToken(
@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SlidingSyncStreamToken:
+ """The same as a `StreamToken`, but includes an extra field at the start for
+ the sliding sync connection token (separated by a '/'). This is used to
+ store per-connection state.
+
+ This then looks something like:
+ 5/s2633508_17_338_6732159_1082514_541479_274711_265584_1_379
+ """
+
+ stream_token: StreamToken
+ connection_token: int
+
+ @staticmethod
+ @cancellable
+ async def from_string(store: "DataStore", string: str) -> "SlidingSyncStreamToken":
+ """Creates a SlidingSyncStreamToken from its textual representation."""
+ try:
+ connection_token_str, stream_token_str = string.split("/", 1)
+ connection_token = int(connection_token_str)
+ stream_token = await StreamToken.from_string(store, stream_token_str)
+
+ return SlidingSyncStreamToken(
+ stream_token=stream_token,
+ connection_token=connection_token,
+ )
+ except CancelledError:
+ raise
+ except Exception:
+ raise SynapseError(400, "Invalid stream token")
+
+ async def to_string(self, store: "DataStore") -> str:
+ """Serializes the token to a string"""
+ stream_token_str = await self.stream_token.to_string(store)
+ return f"{self.connection_token}/{stream_token_str}"
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class PersistedPosition:
"""Position of a newly persisted row with instance that persisted it."""
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 409120470a..0c2ab13c93 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -31,7 +31,14 @@ else:
from pydantic import Extra
from synapse.events import EventBase
-from synapse.types import JsonDict, JsonMapping, StreamToken, UserID
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ Requester,
+ SlidingSyncStreamToken,
+ StreamToken,
+ UserID,
+)
from synapse.types.rest.client import SlidingSyncBody
if TYPE_CHECKING:
@@ -102,7 +109,7 @@ class SlidingSyncConfig(SlidingSyncBody):
"""
user: UserID
- device_id: Optional[str]
+ requester: Requester
# Pydantic config
class Config:
@@ -113,6 +120,31 @@ class SlidingSyncConfig(SlidingSyncBody):
# Allow custom types like `UserID` to be used in the model
arbitrary_types_allowed = True
+ def connection_id(self) -> str:
+ """Return a string identifier for this connection. May clash with
+ connection IDs from different users.
+
+ This is generally a combination of device ID and conn_id. However, both
+ of these are optional (e.g. puppet access tokens don't have device
+ IDs), so this handles those edge cases.
+ """
+
+ # `conn_id` can be null, in which case we default to the empty string
+ # (if conn ID is empty then the client can't have multiple sync loops)
+ conn_id = self.conn_id or ""
+
+ if self.requester.device_id:
+ return f"D/{self.requester.device_id}/{conn_id}"
+
+ if self.requester.access_token_id:
+ # If we don't have a device, then the access token ID should be a
+ # stable ID.
+ return f"A/{self.requester.access_token_id}/{conn_id}"
+
+ # If we have neither then it's likely an AS or some weird token. Either
+ # way we can just fail here.
+ raise Exception("Cannot use sliding sync with access token type")
+
class OperationType(Enum):
"""
@@ -287,7 +319,7 @@ class SlidingSyncResult:
def __bool__(self) -> bool:
return bool(self.to_device)
- next_pos: StreamToken
+ next_pos: SlidingSyncStreamToken
lists: Dict[str, SlidingWindowList]
rooms: Dict[str, RoomResult]
extensions: Extensions
@@ -300,7 +332,7 @@ class SlidingSyncResult:
return bool(self.lists or self.rooms or self.extensions)
@staticmethod
- def empty(next_pos: StreamToken) -> "SlidingSyncResult":
+ def empty(next_pos: SlidingSyncStreamToken) -> "SlidingSyncResult":
"Return a new empty result"
return SlidingSyncResult(
next_pos=next_pos,
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index dbe37bc712..5be8cf5389 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -120,6 +120,9 @@ class SlidingSyncBody(RequestBodyModel):
Sliding Sync API request body.
Attributes:
+ conn_id: An optional string to identify this connection to the server. If this
+ is missing, only 1 sliding sync connection can be made to the server at
+ any one time.
lists: Sliding window API. A map of list key to list information
(:class:`SlidingSyncList`). Max lists: 100. The list keys should be
arbitrary strings which the client is using to refer to the list. Keep this
@@ -315,6 +318,8 @@ class SlidingSyncBody(RequestBodyModel):
to_device: Optional[ToDeviceExtension] = None
+ conn_id: Optional[str]
+
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING:
lists: Optional[Dict[str, SlidingSyncList]] = None
|