summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/app/generic_worker.py13
-rw-r--r--synapse/config/_base.py6
-rw-r--r--synapse/config/database.py2
-rw-r--r--synapse/config/sso.py24
-rw-r--r--synapse/groups/attestations.py19
-rw-r--r--synapse/handlers/federation.py49
-rw-r--r--synapse/replication/http/streams.py12
-rw-r--r--synapse/replication/tcp/resource.py33
-rw-r--r--synapse/replication/tcp/streams/_base.py19
-rw-r--r--synapse/replication/tcp/streams/events.py127
-rw-r--r--synapse/server.pyi5
-rw-r--r--synapse/storage/data_stores/main/events_worker.py60
12 files changed, 264 insertions, 105 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py

index 2a56fe0bd5..d125327f08 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -960,17 +960,22 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - ss = GenericWorkerServer( + hs = GenericWorkerServer( config.server_name, config=config, version_string="Synapse/" + get_version_string(synapse), ) - setup_logging(ss, config, use_worker_options=True) + setup_logging(hs, config, use_worker_options=True) + + hs.setup() + + # Ensure the replication streamer is always started in case we write to any + # streams. Will no-op if no streams can be written to by this worker. + hs.get_replication_streamer() - ss.setup() reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners + "before", "startup", _base.start, hs, config.worker_listeners ) _base.start_worker_reactor("synapse-generic-worker", config) diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index bfa9d28999..30d1050a91 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -657,6 +657,12 @@ def read_config_files(config_files): for config_file in config_files: with open(config_file) as file_stream: yaml_config = yaml.safe_load(file_stream) + + if not isinstance(yaml_config, dict): + err = "File %r is empty or doesn't parse into a key-value map. IGNORING." + print(err % (config_file,)) + continue + specified_config.update(yaml_config) if "server_name" not in specified_config: diff --git a/synapse/config/database.py b/synapse/config/database.py
index c27fef157b..5b662d1b01 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py
@@ -138,7 +138,7 @@ class DatabaseConfig(Config): database_path = config.get("database_path") if multi_database_config and database_config: - raise ConfigError("Can't specify both 'database' and 'datbases' in config") + raise ConfigError("Can't specify both 'database' and 'databases' in config") if multi_database_config: if database_path: diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 6cd37d4324..cac6bc0139 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py
@@ -113,6 +113,30 @@ class SSOConfig(Config): # # * server_name: the homeserver's name. # + # * HTML page which notifies the user that they are authenticating to confirm + # an operation on their account during the user interactive authentication + # process: 'sso_auth_confirm.html'. + # + # When rendering, this template is given the following variables: + # * redirect_url: the URL the user is about to be redirected to. Needs + # manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * description: the operation which the user is being asked to confirm + # + # * HTML page shown after a successful user interactive authentication session: + # 'sso_auth_success.html'. + # + # Note that this page must include the JavaScript which notifies of a successful authentication + # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). + # + # This template has no additional variables. + # + # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) + # attempts to login: 'sso_account_deactivated.html'. + # + # This template has no additional variables. + # # You can see the default templates at: # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates # diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index d950a8b246..1eec3874b6 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py
@@ -37,15 +37,16 @@ An attestation is a signed blob of json that looks like: import logging import random +from typing import Tuple from signedjson.sign import sign_json from twisted.internet import defer from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id +from synapse.util.async_helpers import yieldable_gather_results logger = logging.getLogger(__name__) @@ -162,19 +163,19 @@ class GroupAttestionRenewer(object): def _start_renew_attestations(self): return run_as_background_process("renew_attestations", self._renew_attestations) - @defer.inlineCallbacks - def _renew_attestations(self): + async def _renew_attestations(self): """Called periodically to check if we need to update any of our attestations """ now = self.clock.time_msec() - rows = yield self.store.get_attestations_need_renewals( + rows = await self.store.get_attestations_need_renewals( now + UPDATE_ATTESTATION_TIME_MS ) @defer.inlineCallbacks - def _renew_attestation(group_id, user_id): + def _renew_attestation(group_user: Tuple[str, str]): + group_id, user_id = group_user try: if not self.is_mine_id(group_id): destination = get_domain_from_id(group_id) @@ -207,8 +208,6 @@ class GroupAttestionRenewer(object): "Error renewing attestation of %r in %r", user_id, group_id ) - for row in rows: - group_id = row["group_id"] - user_id = row["user_id"] - - run_in_background(_renew_attestation, group_id, user_id) + await yieldable_gather_results( + _renew_attestation, ((row["group_id"], row["user_id"]) for row in rows) + ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c7aa7acf3b..41b96c0a73 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -343,7 +343,7 @@ class FederationHandler(BaseHandler): ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id - state_maps = list(ours.values()) # type: list[StateMap[str]] + state_maps = list(ours.values()) # type: List[StateMap[str]] # we don't need this any more, let's delete it. del ours @@ -1694,16 +1694,15 @@ class FederationHandler(BaseHandler): return None - @defer.inlineCallbacks - def get_state_for_pdu(self, room_id, event_id): + async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: """Returns the state at the event. i.e. not including said event. """ - event = yield self.store.get_event( + event = await self.store.get_event( event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.state_store.get_state_groups(room_id, [event_id]) + state_groups = await self.state_store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() @@ -1714,7 +1713,7 @@ class FederationHandler(BaseHandler): if "replaces_state" in event.unsigned: prev_id = event.unsigned["replaces_state"] if prev_id != event.event_id: - prev_event = yield self.store.get_event(prev_id) + prev_event = await self.store.get_event(prev_id) results[(event.type, event.state_key)] = prev_event else: del results[(event.type, event.state_key)] @@ -1724,15 +1723,14 @@ class FederationHandler(BaseHandler): else: return [] - @defer.inlineCallbacks - def get_state_ids_for_pdu(self, room_id, event_id): + async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = yield self.store.get_event( + event = await self.store.get_event( event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id]) + state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1751,49 +1749,50 @@ class FederationHandler(BaseHandler): else: return [] - @defer.inlineCallbacks @log_function - def on_backfill_request(self, origin, room_id, pdu_list, limit): - in_room = yield self.auth.check_host_in_room(room_id, origin) + async def on_backfill_request( + self, origin: str, room_id: str, pdu_list: List[str], limit: int + ) -> List[EventBase]: + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") # Synapse asks for 100 events per backfill request. Do not allow more. limit = min(limit, 100) - events = yield self.store.get_backfill_events(room_id, pdu_list, limit) + events = await self.store.get_backfill_events(room_id, pdu_list, limit) - events = yield filter_events_for_server(self.storage, origin, events) + events = await filter_events_for_server(self.storage, origin, events) return events - @defer.inlineCallbacks @log_function - def get_persisted_pdu(self, origin, event_id): + async def get_persisted_pdu( + self, origin: str, event_id: str + ) -> Optional[EventBase]: """Get an event from the database for the given server. Args: - origin [str]: hostname of server which is requesting the event; we + origin: hostname of server which is requesting the event; we will check that the server is allowed to see it. - event_id [str]: id of the event being requested + event_id: id of the event being requested Returns: - Deferred[EventBase|None]: None if we know nothing about the event; - otherwise the (possibly-redacted) event. + None if we know nothing about the event; otherwise the (possibly-redacted) event. Raises: AuthError if the server is not currently in the room """ - event = yield self.store.get_event( + event = await self.store.get_event( event_id, allow_none=True, allow_rejected=True ) if event: - in_room = yield self.auth.check_host_in_room(event.room_id, origin) + in_room = await self.auth.check_host_in_room(event.room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server(self.storage, origin, [event]) + events = await filter_events_for_server(self.storage, origin, [event]) event = events[0] return event else: @@ -2397,7 +2396,7 @@ class FederationHandler(BaseHandler): """ # exclude the state key of the new event from the current_state in the context. if event.is_state(): - event_key = (event.type, event.state_key) + event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]] else: event_key = None state_updates = { diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index ffd4c61993..f35cebc710 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py
@@ -28,7 +28,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): The API looks like: - GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100 + GET /_synapse/replication/get_repl_stream_updates/<stream name>?from_token=0&to_token=10 200 OK @@ -38,6 +38,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): limited: False, } + If there are more rows than can sensibly be returned in one lump, `limited` will be + set to true, and the caller should call again with a new `from_token`. + """ NAME = "get_repl_stream_updates" @@ -52,8 +55,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): self.streams = hs.get_replication_streamer().get_streams() @staticmethod - def _serialize_payload(stream_name, from_token, upto_token, limit): - return {"from_token": from_token, "upto_token": upto_token, "limit": limit} + def _serialize_payload(stream_name, from_token, upto_token): + return {"from_token": from_token, "upto_token": upto_token} async def _handle_request(self, request, stream_name): stream = self.streams.get(stream_name) @@ -62,10 +65,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): from_token = parse_integer(request, "from_token", required=True) upto_token = parse_integer(request, "upto_token", required=True) - limit = parse_integer(request, "limit", required=True) updates, upto_token, limited = await stream.get_updates_since( - from_token, upto_token, limit + from_token, upto_token ) return ( diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index b2d6baa2a2..33d2f589ac 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -17,9 +17,7 @@ import logging import random -from typing import Dict - -from six import itervalues +from typing import Dict, List from prometheus_client import Counter @@ -71,29 +69,28 @@ class ReplicationStreamer(object): def __init__(self, hs): self.store = hs.get_datastore() - self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() self.notifier = hs.get_notifier() - self._server_notices_sender = hs.get_server_notices_sender() self._replication_torture_level = hs.config.replication_torture_level - # List of streams that clients can subscribe to. - # We only support federation stream if federation sending hase been - # disabled on the master. - self.streams = [ - stream(hs) - for stream in itervalues(STREAMS_MAP) - if stream != FederationStream or not hs.config.send_federation - ] + # Work out list of streams that this instance is the source of. + self.streams = [] # type: List[Stream] + if hs.config.worker_app is None: + for stream in STREAMS_MAP.values(): + if stream == FederationStream and hs.config.send_federation: + # We only support federation stream if federation sending + # hase been disabled on the master. + continue - self.streams_by_name = {stream.NAME: stream for stream in self.streams} + self.streams.append(stream(hs)) - self.federation_sender = None - if not hs.config.send_federation: - self.federation_sender = hs.get_federation_sender() + self.streams_by_name = {stream.NAME: stream for stream in self.streams} - self.notifier.add_replication_callback(self.on_notifier_poke) + # Only bother registering the notifier callback if we have streams to + # publish. + if self.streams: + self.notifier.add_replication_callback(self.on_notifier_poke) # Keeps track of whether we are currently checking for updates self.is_looping = False diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index a860072ccf..4ae3cffb1e 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -24,8 +24,8 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates logger = logging.getLogger(__name__) - -MAX_EVENTS_BEHIND = 500000 +# the number of rows to request from an update_function. +_STREAM_UPDATE_TARGET_ROW_COUNT = 100 # Some type aliases to make things a bit easier. @@ -56,7 +56,11 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # * from_token: the previous stream token: the starting point for fetching the # updates # * to_token: the new stream token: the point to get updates up to -# * limit: the maximum number of rows to return +# * target_row_count: a target for the number of rows to be returned. +# +# The update_function is expected to return up to _approximately_ target_row_count rows. +# If there are more updates available, it should set `limited` in the result, and +# it will be called again to get the next batch. # UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] @@ -138,7 +142,7 @@ class Stream(object): return updates, current_token, limited async def get_updates_since( - self, from_token: Token, upto_token: Token, limit: int = 100 + self, from_token: Token, upto_token: Token ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates @@ -156,7 +160,7 @@ class Stream(object): return [], upto_token, False updates, upto_token, limited = await self.update_function( - from_token, upto_token, limit, + from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited @@ -193,10 +197,7 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction: from_token: int, upto_token: int, limit: int ) -> StreamUpdateResult: result = await client( - stream_name=stream_name, - from_token=from_token, - upto_token=upto_token, - limit=limit, + stream_name=stream_name, from_token=from_token, upto_token=upto_token, ) return result["updates"], result["upto_token"], result["limited"] diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 051114596b..f4ebbeea89 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -15,11 +15,12 @@ # limitations under the License. import heapq -from typing import Iterable, Tuple, Type +from collections import Iterable +from typing import List, Tuple, Type import attr -from ._base import Stream, Token, db_query_to_update_function +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,30 +118,120 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() super().__init__( - self._store.get_current_events_token, - db_query_to_update_function(self._update_function), + self._store.get_current_events_token, self._update_function, ) async def _update_function( - self, from_token: Token, current_token: Token, limit: int - ) -> Iterable[tuple]: + self, from_token: Token, current_token: Token, target_row_count: int + ) -> StreamUpdateResult: + + # the events stream merges together three separate sources: + # * new events + # * current_state changes + # * events which were previously outliers, but have now been de-outliered. + # + # The merge operation is complicated by the fact that we only have a single + # "stream token" which is supposed to indicate how far we have got through + # all three streams. It's therefore no good to return rows 1-1000 from the + # "new events" table if the state_deltas are limited to rows 1-100 by the + # target_row_count. + # + # In other words: we must pick a new upper limit, and must return *all* rows + # up to that point for each of the three sources. + # + # Start by trying to split the target_row_count up. We expect to have a + # negligible number of ex-outliers, and a rough approximation based on recent + # traffic on sw1v.org shows that there are approximately the same number of + # event rows between a given pair of stream ids as there are state + # updates, so let's split our target_row_count among those two types. The target + # is only an approximation - it doesn't matter if we end up going a bit over it. + + target_row_count //= 2 + + # now we fetch up to that many rows from the events table + event_rows = await self._store.get_all_new_forward_event_rows( - from_token, current_token, limit - ) - event_updates = ( - (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows - ) + from_token, current_token, target_row_count + ) # type: List[Tuple] + + # we rely on get_all_new_forward_event_rows strictly honouring the limit, so + # that we know it is safe to just take upper_limit = event_rows[-1][0]. + assert ( + len(event_rows) <= target_row_count + ), "get_all_new_forward_event_rows did not honour row limit" + + # if we hit the limit on event_updates, there's no point in going beyond the + # last stream_id in the batch for the other sources. + + if len(event_rows) == target_row_count: + limited = True + upper_limit = event_rows[-1][0] # type: int + else: + limited = False + upper_limit = current_token + + # next up is the state delta table state_rows = await self._store.get_all_updated_current_state_deltas( - from_token, current_token, limit - ) - state_updates = ( - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows - ) + from_token, upper_limit, target_row_count + ) # type: List[Tuple] + + assert len(state_rows) <= target_row_count + + # there can be more than one row per stream_id in that table, so if we hit + # the limit there, we'll need to truncate the results so that we have a complete + # set of changes for all the stream IDs we include. + if len(state_rows) == target_row_count: + assert state_rows[-1][0] <= upper_limit + upper_limit = state_rows[-1][0] - 1 + + # search for the point to truncate the list + for idx in range(len(state_rows) - 1, 0, -1): + if state_rows[idx - 1][0] <= upper_limit: + state_rows = state_rows[:idx] + break + else: + # bother. We didn't get a full set of changes for even a single + # stream id. let's run the query again, without a row limit, but for + # just one stream id. + upper_limit += 1 + state_rows = await self._store.get_all_updated_current_state_deltas( + from_token, upper_limit, limit=None + ) + + limited = True + + # finally, fetch the ex-outliers rows. We assume there are few enough of these + # not to bother with the limit. + + ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( + from_token, upper_limit + ) # type: List[Tuple] + + # we now need to turn the raw database rows returned into tuples suitable + # for the replication protocol (basically, we add an identifier to + # distinguish the row type). At the same time, we can limit the event_rows + # to the max stream_id from state_rows. - all_updates = heapq.merge(event_updates, state_updates) + event_updates = ( + (stream_id, (EventsStreamEventRow.TypeId, rest)) + for (stream_id, *rest) in event_rows + if stream_id <= upper_limit + ) # type: Iterable[Tuple[int, Tuple]] - return all_updates + state_updates = ( + (stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) + for (stream_id, *rest) in state_rows + ) # type: Iterable[Tuple[int, Tuple]] + + ex_outliers_updates = ( + (stream_id, (EventsStreamEventRow.TypeId, rest)) + for (stream_id, *rest) in ex_outliers_rows + ) # type: Iterable[Tuple[int, Tuple]] + + # we need to return a sorted list, so merge them together. + updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) + return updates, upper_limit, limited @classmethod def parse_row(cls, row): diff --git a/synapse/server.pyi b/synapse/server.pyi
index f1a5717028..fc5886f762 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi
@@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage +from synapse.events.builder import EventBuilderFactory class HomeServer(object): @property @@ -121,3 +122,7 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_event_builder_factory(self) -> EventBuilderFactory: + pass + def get_storage(self) -> synapse.storage.Storage: + pass diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index accde349a7..bce9aa7fb8 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py
@@ -973,8 +973,18 @@ class EventsWorkerStore(SQLBaseStore): return self._stream_id_gen.get_current_token() def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) + """Returns new events, for the Events replication stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ def get_all_new_forward_event_rows(txn): sql = ( @@ -989,13 +999,26 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() + return txn.fetchall() - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id + return self.db.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_ex_outlier_stream_rows(self, last_id, current_id): + """Returns de-outliered events, for the Events replication stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ + + def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id" @@ -1006,15 +1029,14 @@ class EventsWorkerStore(SQLBaseStore): " LEFT JOIN event_relations USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering DESC" + " ORDER BY event_stream_ordering ASC" ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - return new_event_updates + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() return self.db.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows + "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) def get_all_new_backfill_event_rows(self, last_id, current_id, limit): @@ -1062,15 +1084,23 @@ class EventsWorkerStore(SQLBaseStore): "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas( + self, from_token: int, to_token: int, limit: Optional[int] + ): def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? + ORDER BY stream_id ASC """ - txn.execute(sql, (from_token, to_token, limit)) + params = [from_token, to_token] + + if limit is not None: + sql += "LIMIT ?" + params.append(limit) + + txn.execute(sql, params) return txn.fetchall() return self.db.runInteraction(