diff --git a/changelog.d/7172.misc b/changelog.d/7172.misc
new file mode 100644
index 0000000000..ffecdf97fe
--- /dev/null
+++ b/changelog.d/7172.misc
@@ -0,0 +1 @@
+Use `stream.current_token()` and remove `stream_positions()`.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 0ace7b787d..97b9b81237 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -413,12 +413,6 @@ class GenericWorkerTyping(object):
# map room IDs to sets of users currently typing
self._room_typing = {}
- def stream_positions(self):
- # We must update this typing token from the response of the previous
- # sync. In particular, the stream id may "reset" back to zero/a low
- # value which we *must* use for the next replication request.
- return {"typing": self._latest_room_serial}
-
def process_replication_rows(self, token, rows):
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
@@ -658,13 +652,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
)
await self.process_and_notify(stream_name, token, rows)
- def get_streams_to_replicate(self):
- args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
- args.update(self.typing_handler.stream_positions())
- if self.send_handler:
- args.update(self.send_handler.stream_positions())
- return args
-
async def process_and_notify(self, stream_name, token, rows):
try:
if self.send_handler:
@@ -799,9 +786,6 @@ class FederationSenderHandler(object):
def wake_destination(self, server: str):
self.federation_sender.wake_destination(server)
- def stream_positions(self):
- return {"federation": self.federation_position}
-
async def process_replication_rows(self, stream_name, token, rows):
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 751c799d94..5d7c8871a4 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, Optional
+from typing import Optional
import six
@@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
self.hs = hs
- def stream_positions(self) -> Dict[str, int]:
- """
- Get the current positions of all the streams this store wants to subscribe to
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- pos = {}
- if self._cache_id_gen:
- pos["caches"] = self._cache_id_gen.get_current_token()
- return pos
-
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index ebe94909cb..65e54b1c71 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedAccountDataStore, self).stream_positions()
- position = self._account_data_id_gen.get_current_token()
- result["user_account_data"] = position
- result["room_account_data"] = position
- result["tag_account_data"] = position
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 0c237c6e0f..c923751e50 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
- def stream_positions(self):
- result = super(SlavedDeviceInboxStore, self).stream_positions()
- result["to_device"] = self._device_inbox_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 23b1650e41..58fb0eaae3 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
- def stream_positions(self):
- result = super(SlavedDeviceStore, self).stream_positions()
- # The user signature stream uses the same stream ID generator as the
- # device list stream, so set them both to the device list ID
- # generator's current token.
- current_token = self._device_list_id_gen.get_current_token()
- result[DeviceListsStream.NAME] = current_token
- result[UserSignatureStream.NAME] = current_token
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index e73342c657..15011259df 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -93,12 +93,6 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedEventStore, self).stream_positions()
- result["events"] = self._stream_id_gen.get_current_token()
- result["backfill"] = -self._backfill_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 2d4fd08cf5..01bcf0e882 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedGroupServerStore, self).stream_positions()
- result["groups"] = self._group_updates_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index ad8f0c15a9..fae3125072 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedPresenceStore, self).stream_positions()
-
- if self.hs.config.use_presence:
- position = self._presence_id_gen.get_current_token()
- result["presence"] = position
-
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index eebd5a1fb6..6138796da4 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedPushRuleStore, self).stream_positions()
- result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index bce8a3d115..67be337945 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
- def stream_positions(self):
- result = super(SlavedPusherStore, self).stream_positions()
- result["pushers"] = self._pushers_id_gen.get_current_token()
- return result
-
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index d40dc6e1f5..993432edcb 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
- def stream_positions(self):
- result = super(SlavedReceiptsStore, self).stream_positions()
- result["receipts"] = self._receipts_id_gen.get_current_token()
- return result
-
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 3a20f45316..10dda8708f 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- def stream_positions(self):
- result = super(RoomStore, self).stream_positions()
- result["public_rooms"] = self._public_room_id_gen.get_current_token()
- return result
-
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2d07b8b2d0..5c28fd4ac3 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,7 +16,7 @@
"""
import logging
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
from twisted.internet.protocol import ReconnectingClientFactory
@@ -100,23 +100,6 @@ class ReplicationDataHandler:
"""
self.store.process_replication_rows(stream_name, token, rows)
- def get_streams_to_replicate(self) -> Dict[str, int]:
- """Called when a new connection has been established and we need to
- subscribe to streams.
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- args = self.store.stream_positions()
- user_account_data = args.pop("user_account_data", None)
- room_account_data = args.pop("room_account_data", None)
- if user_account_data:
- args["account_data"] = user_account_data
- elif room_account_data:
- args["account_data"] = room_account_data
- return args
-
async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 6f7054d5af..d72f3d0cf9 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -314,15 +314,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(cmd.stream_name, [])
# Find where we previously streamed up to.
- current_token = self._replication_data_handler.get_streams_to_replicate().get(
- cmd.stream_name
- )
- if current_token is None:
- logger.warning(
- "Got POSITION for stream we're not subscribed to: %s",
- cmd.stream_name,
- )
- return
+ current_token = stream.current_token()
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 83e16cfe3d..8c104f8d1d 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
import attr
@@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
-from synapse.app.generic_worker import GenericWorkerServer
+from synapse.app.generic_worker import (
+ GenericWorkerReplicationHandler,
+ GenericWorkerServer,
+)
from synapse.http.site import SynapseRequest
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
@@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None
def _build_replication_data_handler(self):
- return TestReplicationDataHandler(self.worker_hs.get_datastore())
+ return TestReplicationDataHandler(self.worker_hs)
def reconnect(self):
if self._client_transport:
@@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
-class TestReplicationDataHandler(ReplicationDataHandler):
+class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
- def __init__(self, store: BaseSlavedStore):
- super().__init__(store)
-
- # streams to subscribe to: map from stream id to position
- self.stream_positions = {} # type: Dict[str, int]
+ def __init__(self, hs: HomeServer):
+ super().__init__(hs)
# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
- def get_streams_to_replicate(self):
- return self.stream_positions
-
async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
- if (
- stream_name in self.stream_positions
- and token > self.stream_positions[stream_name]
- ):
- self.stream_positions[stream_name] = token
-
@attr.s()
class OneShotRequestFactory:
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 1fa28084f9..8bd67bb9f1 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.user_tok = self.login("u1", "pass")
self.reconnect()
- self.test_handler.stream_positions["events"] = 0
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
@@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect()
self.replicate()
- # we should have received all the expected rows in the right order
- received_rows = self.test_handler.received_rdata_rows
+ # we should have received all the expected rows in the right order (as
+ # well as various cache invalidation updates which we ignore)
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
for event in events:
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
@@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect()
self.replicate()
- # now we should have received all the expected rows in the right order.
+ # we should have received all the expected rows in the right order (as
+ # well as various cache invalidation updates which we ignore)
#
# we expect:
#
@@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# of the states that got reverted.
# - two rows for state2
- received_rows = self.test_handler.received_rdata_rows
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
# first check the first two rows, which should be state1
@@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect()
self.replicate()
- # we should have received all the expected rows in the right order
-
- received_rows = self.test_handler.received_rdata_rows
+ # we should have received all the expected rows in the right order (as
+ # well as various cache invalidation updates which we ignore)
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index c122b8589c..df332ee679 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self):
self.reconnect()
- # make the client subscribe to the receipts stream
- self.test_handler.stream_positions.update({"receipts": 0})
-
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 4d354a9db8..e8d17ca68a 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.reconnect()
- # make the client subscribe to the typing stream
- self.test_handler.stream_positions.update({"typing": 0})
-
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
self.reactor.advance(0)
|