summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/register.py4
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py8
-rw-r--r--synapse/replication/slave/storage/push_rule.py10
-rw-r--r--synapse/replication/tcp/commands.py12
-rw-r--r--synapse/replication/tcp/streams/_base.py4
5 files changed, 23 insertions, 15 deletions
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py

index ce9420aa69..a02b27474d 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py
@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin, user_type, address, + shadow_banned, ): """ Args: @@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the regitration. + shadow_banned (bool): Whether to shadow-ban the user """ return { "password_hash": password_hash, @@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "admin": admin, "user_type": user_type, "address": address, + "shadow_banned": shadow_banned, } async def _handle_request(self, request, user_id): @@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin=content["admin"], user_type=content["user_type"], address=content["address"], + shadow_banned=content["shadow_banned"], ) return 200, {} diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..d43eaf3a29 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -33,3 +33,11 @@ class SlavedIdTracker(object): int """ return self._current + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 590187df46..90d90833f9 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.databases.main.push_rule import PushRulesWorkerStore @@ -21,16 +22,13 @@ from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def get_push_rules_stream_token(self): - return ( - self._push_rules_stream_id_gen.get_current_token(), - self._stream_id_gen.get_current_token(), - ) - def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): + # We assert this for the benefit of mypy + assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) + if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(token) for row in rows: diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index d853e4447e..8cd47770c1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -21,9 +21,7 @@ import abc import logging from typing import Tuple, Type -from canonicaljson import json - -from synapse.util import json_encoder as _json_encoder +from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) @@ -125,7 +123,7 @@ class RdataCommand(Command): stream_name, instance_name, None if token == "batch" else int(token), - json.loads(row_json), + json_decoder.decode(row_json), ) def to_line(self): @@ -134,7 +132,7 @@ class RdataCommand(Command): self.stream_name, self.instance_name, str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), + json_encoder.encode(self.row), ) ) @@ -359,7 +357,7 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) @@ -367,7 +365,7 @@ class UserIpCommand(Command): return ( self.user_id + " " - + _json_encoder.encode( + + json_encoder.encode( ( self.access_token, self.ip, diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7a42de3f7d..8c3caf30c9 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -352,7 +352,7 @@ class PushRulesStream(Stream): ) def _current_token(self, instance_name: str) -> int: - push_rules_token, _ = self.store.get_push_rules_stream_token() + push_rules_token = self.store.get_max_push_rules_stream_id() return push_rules_token @@ -405,7 +405,7 @@ class CachesStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, + store.get_cache_stream_token_for_writer, store.get_all_updated_caches, )