summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/_base.py2
-rw-r--r--synapse/replication/http/register.py4
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py14
-rw-r--r--synapse/replication/slave/storage/account_data.py4
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py2
-rw-r--r--synapse/replication/slave/storage/devices.py7
-rw-r--r--synapse/replication/slave/storage/groups.py2
-rw-r--r--synapse/replication/slave/storage/presence.py2
-rw-r--r--synapse/replication/slave/storage/push_rule.py12
-rw-r--r--synapse/replication/slave/storage/pushers.py2
-rw-r--r--synapse/replication/slave/storage/receipts.py2
-rw-r--r--synapse/replication/slave/storage/room.py2
-rw-r--r--synapse/replication/tcp/client.py6
-rw-r--r--synapse/replication/tcp/commands.py12
-rw-r--r--synapse/replication/tcp/protocol.py2
-rw-r--r--synapse/replication/tcp/resource.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py6
-rw-r--r--synapse/replication/tcp/streams/events.py4
18 files changed, 48 insertions, 39 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py

index 6a28c2db9d..ba16f22c91 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py
@@ -33,7 +33,7 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) -class ReplicationEndpoint(object): +class ReplicationEndpoint: """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index ce9420aa69..a02b27474d 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py
@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin, user_type, address, + shadow_banned, ): """ Args: @@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the regitration. + shadow_banned (bool): Whether to shadow-ban the user """ return { "password_hash": password_hash, @@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "admin": admin, "user_type": user_type, "address": address, + "shadow_banned": shadow_banned, } async def _handle_request(self, request, user_id): @@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin=content["admin"], user_type=content["user_type"], address=content["address"], + shadow_banned=content["shadow_banned"], ) return 200, {} diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..eb74903d68 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -16,14 +16,14 @@ from synapse.storage.util.id_generators import _load_current_id -class SlavedIdTracker(object): +class SlavedIdTracker: def __init__(self, db_conn, table, column, extra_tables=[], step=1): self.step = step self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self.advance(_load_current_id(db_conn, table, column)) + self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, new_id): + def advance(self, instance_name, new_id): self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self): @@ -33,3 +33,11 @@ class SlavedIdTracker(object): int """ return self._current + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 154f0e687c..bb66ba9b80 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py
@@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(token) + self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(token) + self._account_data_id_gen.advance(instance_name, token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index ee7f69a918..533d927701 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ToDeviceStream.NAME: - self._device_inbox_id_gen.advance(token) + self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 722f3745e9..3b788c9625 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py
@@ -48,12 +48,15 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 3291558c7a..567b4a5cc1 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py
@@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == GroupServerStream.NAME: - self._group_updates_id_gen.advance(token) + self._group_updates_id_gen.advance(instance_name, token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index a912c04360..025f6f6be8 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py
@@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(token) + self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 590187df46..de904c943c 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.databases.main.push_rule import PushRulesWorkerStore @@ -21,18 +22,15 @@ from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def get_push_rules_stream_token(self): - return ( - self._push_rules_stream_id_gen.get_current_token(), - self._stream_id_gen.get_current_token(), - ) - def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): + # We assert this for the benefit of mypy + assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) + if stream_name == PushRulesStream.NAME: - self._push_rules_stream_id_gen.advance(token) + self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 63300e5da6..9da218bfe8 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py
@@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(token) + self._pushers_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 17ba1f22ac..5c2986e050 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py
@@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ReceiptsStream.NAME: - self._receipts_id_gen.advance(token) + self._receipts_id_gen.advance(instance_name, token) for row in rows: self.invalidate_caches_for_receipt( row.room_id, row.receipt_type, row.user_id diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 427c81772b..80ae803ad9 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py
@@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PublicRoomsStream.NAME: - self._public_room_id_gen.advance(token) + self._public_room_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index fcf8ebf1e7..d6ecf5b327 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,6 @@ # limitations under the License. """A replication client for use by synapse workers. """ -import heapq import logging from typing import TYPE_CHECKING, Dict, List, Tuple @@ -219,9 +218,8 @@ class ReplicationDataHandler: waiting_list = self._streams_to_waiters.setdefault(stream_name, []) - # We insert into the list using heapq as it is more efficient than - # pushing then resorting each time. - heapq.heappush(waiting_list, (position, deferred)) + waiting_list.append((position, deferred)) + waiting_list.sort(key=lambda t: t[0]) # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index d853e4447e..8cd47770c1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -21,9 +21,7 @@ import abc import logging from typing import Tuple, Type -from canonicaljson import json - -from synapse.util import json_encoder as _json_encoder +from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) @@ -125,7 +123,7 @@ class RdataCommand(Command): stream_name, instance_name, None if token == "batch" else int(token), - json.loads(row_json), + json_decoder.decode(row_json), ) def to_line(self): @@ -134,7 +132,7 @@ class RdataCommand(Command): self.stream_name, self.instance_name, str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), + json_encoder.encode(self.row), ) ) @@ -359,7 +357,7 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) @@ -367,7 +365,7 @@ class UserIpCommand(Command): return ( self.user_id + " " - + _json_encoder.encode( + + json_encoder.encode( ( self.access_token, self.ip, diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0350923898..0b0d204e64 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -113,7 +113,7 @@ PING_TIMEOUT_MULTIPLIER = 5 PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER -class ConnectionStates(object): +class ConnectionStates: CONNECTING = "connecting" ESTABLISHED = "established" PAUSED = "paused" diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 41569305df..04d894fb3d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -58,7 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): ) -class ReplicationStreamer(object): +class ReplicationStreamer: """Handles replication connections. This needs to be poked when new replication data may be available. When new diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a42de3f7d..682d47f402 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -79,7 +79,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] -class Stream(object): +class Stream: """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last @@ -352,7 +352,7 @@ class PushRulesStream(Stream): ) def _current_token(self, instance_name: str) -> int: - push_rules_token, _ = self.store.get_push_rules_stream_token() + push_rules_token = self.store.get_max_push_rules_stream_id() return push_rules_token @@ -405,7 +405,7 @@ class CachesStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, + store.get_cache_stream_token_for_writer, store.get_all_updated_caches, ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 16c63ff4ec..f929fc3954 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -49,14 +49,14 @@ data part are: @attr.s(slots=True, frozen=True) -class EventsStreamRow(object): +class EventsStreamRow: """A parsed row from the events replication stream""" type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow -class BaseEventsStreamRow(object): +class BaseEventsStreamRow: """Base class for rows to be sent in the events stream. Specifies how to identify, serialize and deserialize the different types.