diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 6a28c2db9d..ba16f22c91 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -33,7 +33,7 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
-class ReplicationEndpoint(object):
+class ReplicationEndpoint:
"""Helper base class for defining new replication HTTP endpoints.
This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index ce9420aa69..a02b27474d 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin,
user_type,
address,
+ shadow_banned,
):
"""
Args:
@@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration.
+ shadow_banned (bool): Whether to shadow-ban the user
"""
return {
"password_hash": password_hash,
@@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"admin": admin,
"user_type": user_type,
"address": address,
+ "shadow_banned": shadow_banned,
}
async def _handle_request(self, request, user_id):
@@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin=content["admin"],
user_type=content["user_type"],
address=content["address"],
+ shadow_banned=content["shadow_banned"],
)
return 200, {}
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..eb74903d68 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -16,14 +16,14 @@
from synapse.storage.util.id_generators import _load_current_id
-class SlavedIdTracker(object):
+class SlavedIdTracker:
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
- self.advance(_load_current_id(db_conn, table, column))
+ self.advance(None, _load_current_id(db_conn, table, column))
- def advance(self, new_id):
+ def advance(self, instance_name, new_id):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
int
"""
return self._current
+
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+
+ For streams with single writers this is equivalent to
+ `get_current_token`.
+ """
+ return self.get_current_token()
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 154f0e687c..bb66ba9b80 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(token)
+ self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
- self._account_data_id_gen.advance(token)
+ self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index ee7f69a918..533d927701 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
- self._device_inbox_id_gen.advance(token)
+ self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 722f3745e9..3b788c9625 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,12 +48,15 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
- self._device_list_id_gen.advance(token)
+ self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
- self._device_list_id_gen.advance(token)
+ self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 3291558c7a..567b4a5cc1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == GroupServerStream.NAME:
- self._group_updates_id_gen.advance(token)
+ self._group_updates_id_gen.advance(instance_name, token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index a912c04360..025f6f6be8 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PresenceStream.NAME:
- self._presence_id_gen.advance(token)
+ self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 590187df46..de904c943c 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -21,18 +22,15 @@ from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def get_push_rules_stream_token(self):
- return (
- self._push_rules_stream_id_gen.get_current_token(),
- self._stream_id_gen.get_current_token(),
- )
-
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
+ # We assert this for the benefit of mypy
+ assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
if stream_name == PushRulesStream.NAME:
- self._push_rules_stream_id_gen.advance(token)
+ self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 63300e5da6..9da218bfe8 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PushersStream.NAME:
- self._pushers_id_gen.advance(token)
+ self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 17ba1f22ac..5c2986e050 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
- self._receipts_id_gen.advance(token)
+ self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 427c81772b..80ae803ad9 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PublicRoomsStream.NAME:
- self._public_room_id_gen.advance(token)
+ self._public_room_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index fcf8ebf1e7..d6ecf5b327 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,6 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
-import heapq
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple
@@ -219,9 +218,8 @@ class ReplicationDataHandler:
waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
- # We insert into the list using heapq as it is more efficient than
- # pushing then resorting each time.
- heapq.heappush(waiting_list, (position, deferred))
+ waiting_list.append((position, deferred))
+ waiting_list.sort(key=lambda t: t[0])
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index d853e4447e..8cd47770c1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -21,9 +21,7 @@ import abc
import logging
from typing import Tuple, Type
-from canonicaljson import json
-
-from synapse.util import json_encoder as _json_encoder
+from synapse.util import json_decoder, json_encoder
logger = logging.getLogger(__name__)
@@ -125,7 +123,7 @@ class RdataCommand(Command):
stream_name,
instance_name,
None if token == "batch" else int(token),
- json.loads(row_json),
+ json_decoder.decode(row_json),
)
def to_line(self):
@@ -134,7 +132,7 @@ class RdataCommand(Command):
self.stream_name,
self.instance_name,
str(self.token) if self.token is not None else "batch",
- _json_encoder.encode(self.row),
+ json_encoder.encode(self.row),
)
)
@@ -359,7 +357,7 @@ class UserIpCommand(Command):
def from_line(cls, line):
user_id, jsn = line.split(" ", 1)
- access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
+ access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
@@ -367,7 +365,7 @@ class UserIpCommand(Command):
return (
self.user_id
+ " "
- + _json_encoder.encode(
+ + json_encoder.encode(
(
self.access_token,
self.ip,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0350923898..0b0d204e64 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -113,7 +113,7 @@ PING_TIMEOUT_MULTIPLIER = 5
PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER
-class ConnectionStates(object):
+class ConnectionStates:
CONNECTING = "connecting"
ESTABLISHED = "established"
PAUSED = "paused"
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 41569305df..04d894fb3d 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -58,7 +58,7 @@ class ReplicationStreamProtocolFactory(Factory):
)
-class ReplicationStreamer(object):
+class ReplicationStreamer:
"""Handles replication connections.
This needs to be poked when new replication data may be available. When new
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a42de3f7d..682d47f402 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -79,7 +79,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
-class Stream(object):
+class Stream:
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
@@ -352,7 +352,7 @@ class PushRulesStream(Stream):
)
def _current_token(self, instance_name: str) -> int:
- push_rules_token, _ = self.store.get_push_rules_stream_token()
+ push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token
@@ -405,7 +405,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- store.get_cache_stream_token,
+ store.get_cache_stream_token_for_writer,
store.get_all_updated_caches,
)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 16c63ff4ec..f929fc3954 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -49,14 +49,14 @@ data part are:
@attr.s(slots=True, frozen=True)
-class EventsStreamRow(object):
+class EventsStreamRow:
"""A parsed row from the events replication stream"""
type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
data = attr.ib() # BaseEventsStreamRow
-class BaseEventsStreamRow(object):
+class BaseEventsStreamRow:
"""Base class for rows to be sent in the events stream.
Specifies how to identify, serialize and deserialize the different types.
|