diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index acb0bd18f7..3f4d3fc51a 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -153,7 +153,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
argument list.
Returns:
- dict: If POST/PUT request then dictionary must be JSON serialisable,
+ If POST/PUT request then dictionary must be JSON serialisable,
otherwise must be appropriate for adding as query args.
"""
return {}
@@ -184,8 +184,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
client = hs.get_simple_http_client()
local_instance_name = hs.get_instance_name()
+ # The value of these option should match the replication listener settings
master_host = hs.config.worker.worker_replication_host
master_port = hs.config.worker.worker_replication_http_port
+ master_tls = hs.config.worker.worker_replication_http_tls
instance_map = hs.config.worker.instance_map
@@ -205,9 +207,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if instance_name == "master":
host = master_host
port = master_port
+ tls = master_tls
elif instance_name in instance_map:
host = instance_map[instance_name].host
port = instance_map[instance_name].port
+ tls = instance_map[instance_name].tls
else:
raise Exception(
"Instance %r not in 'instance_map' config" % (instance_name,)
@@ -238,7 +242,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"Unknown METHOD on %s replication endpoint" % (cls.NAME,)
)
- uri = "http://%s:%s/_synapse/replication/%s/%s" % (
+ # Here the protocol is hard coded to be http by default or https in case the replication
+ # port is set to have tls true.
+ scheme = "https" if tls else "http"
+ uri = "%s://%s:%s/_synapse/replication/%s/%s" % (
+ scheme,
host,
port,
cls.NAME,
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 3d63645726..7c4941c3d3 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -13,11 +13,12 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
+from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
@@ -62,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.device_list_updater = hs.get_device_handler().device_list_updater
+ from synapse.handlers.device import DeviceHandler
+
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_list_updater = handler.device_list_updater
+
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -72,11 +78,77 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
- ) -> Tuple[int, JsonDict]:
+ ) -> Tuple[int, Optional[JsonDict]]:
user_devices = await self.device_list_updater.user_device_resync(user_id)
return 200, user_devices
+class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
+ """Ask master to upload keys for the user and send them out over federation to
+ update other servers.
+
+ For now, only the master is permitted to handle key upload requests;
+ any worker can handle key query requests (since they're read-only).
+
+ Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on
+ the main process to accomplish this.
+
+ Defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload
+ Request format(borrowed and expanded from KeyUploadServlet):
+
+ POST /_synapse/replication/upload_keys_for_user
+
+ {
+ "user_id": "<user_id>",
+ "device_id": "<device_id>",
+ "keys": {
+ ....this part can be found in KeyUploadServlet in rest/client/keys.py....
+ }
+ }
+
+ Response is equivalent to ` /_matrix/client/v3/keys/upload` found in KeyUploadServlet
+
+ """
+
+ NAME = "upload_keys_for_user"
+ PATH_ARGS = ()
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ self.store = hs.get_datastores().main
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ async def _serialize_payload( # type: ignore[override]
+ user_id: str, device_id: str, keys: JsonDict
+ ) -> JsonDict:
+
+ return {
+ "user_id": user_id,
+ "device_id": device_id,
+ "keys": keys,
+ }
+
+ async def _handle_request( # type: ignore[override]
+ self, request: Request
+ ) -> Tuple[int, JsonDict]:
+ content = parse_json_object_from_request(request)
+
+ user_id = content["user_id"]
+ device_id = content["device_id"]
+ keys = content["keys"]
+
+ results = await self.e2e_keys_handler.upload_keys_for_user(
+ user_id, device_id, keys
+ )
+
+ return 200, results
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
+ ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 61abb529c8..976c283360 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -39,6 +39,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
self.store = hs.get_datastores().main
self.registration_handler = hs.get_registration_handler()
+ # Default value if the worker that sent the replication request did not include
+ # an 'approved' property.
+ if (
+ hs.config.experimental.msc3866.enabled
+ and hs.config.experimental.msc3866.require_approval_for_new_accounts
+ ):
+ self._approval_default = False
+ else:
+ self._approval_default = True
+
@staticmethod
async def _serialize_payload( # type: ignore[override]
user_id: str,
@@ -92,6 +102,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
await self.registration_handler.check_registration_ratelimit(content["address"])
+ # Always default admin users to approved (since it means they were created by
+ # an admin).
+ approved_default = self._approval_default
+ if content["admin"]:
+ approved_default = True
+
await self.registration_handler.register_with_store(
user_id=user_id,
password_hash=content["password_hash"],
@@ -103,7 +119,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type=content["user_type"],
address=content["address"],
shadow_banned=content["shadow_banned"],
- approved=content["approved"],
+ approved=content.get("approved", approved_default),
)
return 200, {}
diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py
deleted file mode 100644
index f43a360a80..0000000000
--- a/synapse/replication/slave/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py
deleted file mode 100644
index f43a360a80..0000000000
--- a/synapse/replication/slave/storage/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
deleted file mode 100644
index 8f3f953ed4..0000000000
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import List, Optional, Tuple
-
-from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
-
-
-class SlavedIdTracker(AbstractStreamIdTracker):
- """Tracks the "current" stream ID of a stream with a single writer.
-
- See `AbstractStreamIdTracker` for more details.
-
- Note that this class does not work correctly when there are multiple
- writers.
- """
-
- def __init__(
- self,
- db_conn: LoggingDatabaseConnection,
- table: str,
- column: str,
- extra_tables: Optional[List[Tuple[str, str]]] = None,
- step: int = 1,
- ):
- self.step = step
- self._current = _load_current_id(db_conn, table, column, step)
- if extra_tables:
- for table, column in extra_tables:
- self.advance(None, _load_current_id(db_conn, table, column))
-
- def advance(self, instance_name: Optional[str], new_id: int) -> None:
- self._current = (max if self.step > 0 else min)(self._current, new_id)
-
- def get_current_token(self) -> int:
- return self._current
-
- def get_current_token_for_writer(self, instance_name: str) -> int:
- return self.get_current_token()
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
deleted file mode 100644
index 6fcade510a..0000000000
--- a/synapse/replication/slave/storage/devices.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING, Any, Iterable
-
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.devices import DeviceWorkerStore
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class SlavedDeviceStore(DeviceWorkerStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- self.hs = hs
-
- self._device_list_id_gen = SlavedIdTracker(
- db_conn,
- "device_lists_stream",
- "stream_id",
- extra_tables=[
- ("user_signature_stream", "stream_id"),
- ("device_lists_outbound_pokes", "stream_id"),
- ("device_lists_changes_in_room", "stream_id"),
- ],
- )
-
- super().__init__(database, db_conn, hs)
-
- def get_device_stream_token(self) -> int:
- return self._device_list_id_gen.get_current_token()
-
- def process_replication_rows(
- self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
- ) -> None:
- if stream_name == DeviceListsStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
- self._invalidate_caches_for_devices(token, rows)
- elif stream_name == UserSignatureStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
- for row in rows:
- self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
-
- def _invalidate_caches_for_devices(
- self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
- ) -> None:
- for row in rows:
- # The entities are either user IDs (starting with '@') whose devices
- # have changed, or remote servers that we need to tell about
- # changes.
- if row.entity.startswith("@"):
- self._device_list_stream_cache.entity_has_changed(row.entity, token)
- self.get_cached_devices_for_user.invalidate((row.entity,))
- self._get_cached_user_device.invalidate((row.entity,))
- self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
-
- else:
- self._device_list_federation_stream_cache.entity_has_changed(
- row.entity, token
- )
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
deleted file mode 100644
index fe47778cb1..0000000000
--- a/synapse/replication/slave/storage/events.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-from typing import TYPE_CHECKING
-
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
-from synapse.storage.databases.main.event_push_actions import (
- EventPushActionsWorkerStore,
-)
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.databases.main.relations import RelationsWorkerStore
-from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.storage.databases.main.signatures import SignatureWorkerStore
-from synapse.storage.databases.main.state import StateGroupWorkerStore
-from synapse.storage.databases.main.stream import StreamWorkerStore
-from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-# So, um, we want to borrow a load of functions intended for reading from
-# a DataStore, but we don't want to take functions that either write to the
-# DataStore or are cached and don't have cache invalidation logic.
-#
-# Rather than write duplicate versions of those functions, or lift them to
-# a common base class, we going to grab the underlying __func__ object from
-# the method descriptor on the DataStore and chuck them into our class.
-
-
-class SlavedEventStore(
- EventFederationWorkerStore,
- RoomMemberWorkerStore,
- EventPushActionsWorkerStore,
- StreamWorkerStore,
- StateGroupWorkerStore,
- SignatureWorkerStore,
- EventsWorkerStore,
- UserErasureWorkerStore,
- RelationsWorkerStore,
-):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
- db_conn,
- "current_state_delta_stream",
- entity_column="room_id",
- stream_column="stream_id",
- max_value=events_max, # As we share the stream id with events token
- limit=1000,
- )
- self._curr_state_delta_stream_cache = StreamChangeCache(
- "_curr_state_delta_stream_cache",
- min_curr_state_delta_id,
- prefilled_cache=curr_state_delta_prefill,
- )
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
deleted file mode 100644
index c52679cd60..0000000000
--- a/synapse/replication/slave/storage/filtering.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING
-
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.filtering import FilteringStore
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class SlavedFilteringStore(SQLBaseStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- # Filters are immutable so this cache doesn't need to be expired
- get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
deleted file mode 100644
index a00b38c512..0000000000
--- a/synapse/replication/slave/storage/keys.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.storage.databases.main.keys import KeyStore
-
-# KeyStore isn't really safe to use from a worker, but for now we do so and hope that
-# the races it creates aren't too bad.
-
-SlavedKeyStore = KeyStore
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
deleted file mode 100644
index 5e65eaf1e0..0000000000
--- a/synapse/replication/slave/storage/push_rule.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Any, Iterable
-
-from synapse.replication.tcp.streams import PushRulesStream
-from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-
-from .events import SlavedEventStore
-
-
-class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def get_max_push_rules_stream_id(self) -> int:
- return self._push_rules_stream_id_gen.get_current_token()
-
- def process_replication_rows(
- self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
- ) -> None:
- if stream_name == PushRulesStream.NAME:
- self._push_rules_stream_id_gen.advance(instance_name, token)
- for row in rows:
- self.get_push_rules_for_user.invalidate((row.user_id,))
- self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
deleted file mode 100644
index 44ed20e424..0000000000
--- a/synapse/replication/slave/storage/pushers.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING, Any, Iterable
-
-from synapse.replication.tcp.streams import PushersStream
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-from synapse.storage.databases.main.pusher import PusherWorkerStore
-
-from ._slaved_id_tracker import SlavedIdTracker
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class SlavedPusherStore(PusherWorkerStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
- self._pushers_id_gen = SlavedIdTracker( # type: ignore
- db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
- )
-
- def get_pushers_stream_token(self) -> int:
- return self._pushers_id_gen.get_current_token()
-
- def process_replication_rows(
- self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
- ) -> None:
- if stream_name == PushersStream.NAME:
- self._pushers_id_gen.advance(instance_name, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 18252a2958..b4dad47b45 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -36,12 +36,14 @@ from synapse.replication.tcp.streams import (
TagAccountDataStream,
ToDeviceStream,
TypingStream,
+ UnPartialStatedRoomStream,
)
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStreamRow
from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID
from synapse.util.async_helpers import Linearizer, timeout_deferred
from synapse.util.metrics import Measure
@@ -117,6 +119,7 @@ class ReplicationDataHandler:
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler()
+ self._state_storage_controller = hs.get_storage_controllers().state
self._notify_pushers = hs.config.worker.start_pushers
self._pusher_pool = hs.get_pusherpool()
@@ -236,6 +239,14 @@ class ReplicationDataHandler:
self.notifier.notify_user_joined_room(
row.data.event_id, row.data.room_id
)
+ elif stream_name == UnPartialStatedRoomStream.NAME:
+ for row in rows:
+ assert isinstance(row, UnPartialStatedRoomStreamRow)
+
+ # Wake up any tasks waiting for the room to be un-partial-stated.
+ self._state_storage_controller.notify_room_un_partial_stated(
+ row.room_id
+ )
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 7763ffb2d0..56a5c21910 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -245,7 +245,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self._parse_and_dispatch_line(line)
def _parse_and_dispatch_line(self, line: bytes) -> None:
- if line.strip() == "":
+ if line.strip() == b"":
# Ignore blank lines
return
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index b1cd55bf6f..8575666d9c 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -42,6 +42,7 @@ from synapse.replication.tcp.streams._base import (
)
from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.federation import FederationStream
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream
STREAMS_MAP = {
stream.NAME: stream
@@ -61,6 +62,7 @@ STREAMS_MAP = {
TagAccountDataStream,
AccountDataStream,
UserSignatureStream,
+ UnPartialStatedRoomStream,
)
}
@@ -80,4 +82,5 @@ __all__ = [
"TagAccountDataStream",
"AccountDataStream",
"UserSignatureStream",
+ "UnPartialStatedRoomStream",
]
diff --git a/synapse/replication/tcp/streams/partial_state.py b/synapse/replication/tcp/streams/partial_state.py
new file mode 100644
index 0000000000..18f087ffa2
--- /dev/null
+++ b/synapse/replication/tcp/streams/partial_state.py
@@ -0,0 +1,48 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import attr
+
+from synapse.replication.tcp.streams import Stream
+from synapse.replication.tcp.streams._base import current_token_without_instance
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UnPartialStatedRoomStreamRow:
+ # ID of the room that has been un-partial-stated.
+ room_id: str
+
+
+class UnPartialStatedRoomStream(Stream):
+ """
+ Stream to notify about rooms becoming un-partial-stated;
+ that is, when the background sync finishes such that we now have full state for
+ the room.
+ """
+
+ NAME = "un_partial_stated_room"
+ ROW_TYPE = UnPartialStatedRoomStreamRow
+
+ def __init__(self, hs: "HomeServer"):
+ store = hs.get_datastores().main
+ super().__init__(
+ hs.get_instance_name(),
+ # TODO(faster_joins, multiple writers): we need to account for instance names
+ current_token_without_instance(store.get_un_partial_stated_rooms_token),
+ store.get_un_partial_stated_rooms_from_stream,
+ )
|