diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 5526015ddb..6912165622 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -747,7 +747,7 @@ class PresenceHandler(object):
return False
- async def get_all_presence_updates(self, last_id, current_id):
+ async def get_all_presence_updates(self, last_id, current_id, limit):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -762,7 +762,7 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
- rows = await self.store.get_all_presence_updates(last_id, current_id)
+ rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
return rows
def notify_new_event(self):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 391bceb0c4..c7bc14c623 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
+from typing import List
from twisted.internet import defer
@@ -257,7 +258,13 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
- async def get_all_typing_updates(self, last_id, current_id):
+ async def get_all_typing_updates(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[tuple]:
+ """Get up to `limit` typing updates between the given tokens, earliest
+ updates first. Each update is a `(serial, room_id, typing)` tuple.
+ """
+
if last_id == current_id:
return []
@@ -275,7 +282,7 @@ class TypingHandler(object):
typing = self._room_typing[room_id]
rows.append((serial, room_id, list(typing)))
rows.sort()
- return rows
+ return rows[:limit]
def get_current_token(self):
return self._latest_room_serial
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce9d1fae12..6e2ebaf614 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -166,11 +166,6 @@ class ReplicationStreamer(object):
self.pending_updates = False
with Measure(self.clock, "repl.stream.get_updates"):
- # First we tell the streams that they should update their
- # current tokens.
- for stream in self.streams:
- stream.advance_current_token()
-
all_streams = self.streams
if self._replication_torture_level is not None:
@@ -180,7 +175,7 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
- if stream.last_token == stream.upto_token:
+ if stream.last_token == stream.current_token():
continue
if self._replication_torture_level:
@@ -192,7 +187,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
- stream.upto_token,
+ stream.current_token(),
)
try:
updates, current_token = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a8b6e9df1..abf5c6c6a8 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -17,10 +17,12 @@
import itertools
import logging
from collections import namedtuple
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple
import attr
+from synapse.types import JsonDict
+
logger = logging.getLogger(__name__)
@@ -119,13 +121,12 @@ class Stream(object):
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
- time it was called up until the point `advance_current_token` was called.
+ time it was called.
"""
NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
- _LIMITED = True # Whether the update function takes a limit
@classmethod
def parse_row(cls, row):
@@ -146,26 +147,15 @@ class Stream(object):
# The token from which we last asked for updates
self.last_token = self.current_token()
- # The token that we will get updates up to
- self.upto_token = self.current_token()
-
- def advance_current_token(self):
- """Updates `upto_token` to "now", which updates up until which point
- get_updates[_since] will fetch rows till.
- """
- self.upto_token = self.current_token()
-
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
- self.upto_token = self.current_token()
- self.last_token = self.upto_token
+ self.last_token = self.current_token()
async def get_updates(self):
"""Gets all updates since the last time this function was called (or
- since the stream was constructed if it hadn't been called before),
- until the `upto_token`
+ since the stream was constructed if it hadn't been called before).
Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]:
@@ -178,44 +168,45 @@ class Stream(object):
return updates, current_token
- async def get_updates_since(self, from_token):
+ async def get_updates_since(
+ self, from_token: int
+ ) -> Tuple[List[Tuple[int, JsonDict]], int]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+ Resolves to a pair `(updates, new_last_token)`, where `updates` is
+ a list of `(token, row)` entries and `new_last_token` is the new
+ position in stream.
"""
+
if from_token in ("NOW", "now"):
- return [], self.upto_token
+ return [], self.current_token()
- current_token = self.upto_token
+ current_token = self.current_token()
from_token = int(from_token)
if from_token == current_token:
return [], current_token
- logger.info("get_updates_since: %s", self.__class__)
- if self._LIMITED:
- rows = await self.update_function(
- from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
- )
+ rows = await self.update_function(
+ from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+ )
- # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
- rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
- else:
- rows = await self.update_function(from_token, current_token)
+ # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
+ rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator.
- if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
+ if len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
+ # The update function didn't hit the limit, so we must have got all
+ # the updates to `current_token`, and can return that as our new
+ # stream position.
return updates, current_token
def current_token(self):
@@ -227,9 +218,8 @@ class Stream(object):
"""
raise NotImplementedError()
- def update_function(self, from_token, current_token, limit=None):
- """Get updates between from_token and to_token. If Stream._LIMITED is
- True then limit is provided, otherwise it's not.
+ def update_function(self, from_token, current_token, limit):
+ """Get updates between from_token and to_token.
Returns:
Deferred(list(tuple)): the first entry in the tuple is the token for
@@ -257,7 +247,6 @@ class BackfillStream(Stream):
class PresenceStream(Stream):
NAME = "presence"
- _LIMITED = False
ROW_TYPE = PresenceStreamRow
def __init__(self, hs):
@@ -272,7 +261,6 @@ class PresenceStream(Stream):
class TypingStream(Stream):
NAME = "typing"
- _LIMITED = False
ROW_TYPE = TypingStreamRow
def __init__(self, hs):
@@ -372,7 +360,6 @@ class DeviceListsStream(Stream):
"""
NAME = "device_lists"
- _LIMITED = False
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs):
@@ -462,7 +449,6 @@ class UserSignatureStream(Stream):
"""
NAME = "user_signature"
- _LIMITED = False
ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs):
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 4c19c02bbc..2d47cfd131 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -576,7 +576,7 @@ class DeviceWorkerStore(SQLBaseStore):
return set()
async def get_all_device_list_changes_for_remotes(
- self, from_key: int, to_key: int
+ self, from_key: int, to_key: int, limit: int,
) -> List[Tuple[int, str]]:
"""Return a list of `(stream_id, entity)` which is the combined list of
changes to devices and which destinations need to be poked. Entity is
@@ -592,10 +592,17 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
"""
return await self.db.execute(
- "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
+ "get_all_device_list_changes_for_remotes",
+ None,
+ sql,
+ from_key,
+ to_key,
+ limit,
)
@cached(max_entries=10000)
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 001a53f9b4..bcf746b7ef 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
- def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
@@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
"""
sql = """
- SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id
+ ORDER BY stream_id ASC
+ LIMIT ?
"""
return self.db.execute(
- "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ "get_all_user_signature_changes_for_remotes",
+ None,
+ sql,
+ from_key,
+ to_key,
+ limit,
)
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index 604c8b7ddd..dab31e0c2d 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
- for state in presence_states
+ for stream_id, state in zip(stream_orderings, presence_states)
],
)
@@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
)
txn.execute(sql + clause, [stream_id] + list(args))
- def get_all_presence_updates(self, last_id, current_id):
+ def get_all_presence_updates(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_presence_updates_txn(txn):
- sql = (
- "SELECT stream_id, user_id, state, last_active_ts,"
- " last_federation_update_ts, last_user_sync_ts, status_msg,"
- " currently_active"
- " FROM presence_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- )
- txn.execute(sql, (last_id, current_id))
+ sql = """
+ SELECT stream_id, user_id, state, last_active_ts,
+ last_federation_update_ts, last_user_sync_ts,
+ status_msg,
+ currently_active
+ FROM presence_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(
|