summary refs log tree commit diff
path: root/synapse/replication/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/http')
-rw-r--r--synapse/replication/http/_base.py12
-rw-r--r--synapse/replication/http/devices.py78
-rw-r--r--synapse/replication/http/register.py18
3 files changed, 102 insertions, 6 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index acb0bd18f7..3f4d3fc51a 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -153,7 +153,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         argument list.
 
         Returns:
-            dict: If POST/PUT request then dictionary must be JSON serialisable,
+            If POST/PUT request then dictionary must be JSON serialisable,
             otherwise must be appropriate for adding as query args.
         """
         return {}
@@ -184,8 +184,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         client = hs.get_simple_http_client()
         local_instance_name = hs.get_instance_name()
 
+        # The value of these option should match the replication listener settings
         master_host = hs.config.worker.worker_replication_host
         master_port = hs.config.worker.worker_replication_http_port
+        master_tls = hs.config.worker.worker_replication_http_tls
 
         instance_map = hs.config.worker.instance_map
 
@@ -205,9 +207,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 if instance_name == "master":
                     host = master_host
                     port = master_port
+                    tls = master_tls
                 elif instance_name in instance_map:
                     host = instance_map[instance_name].host
                     port = instance_map[instance_name].port
+                    tls = instance_map[instance_name].tls
                 else:
                     raise Exception(
                         "Instance %r not in 'instance_map' config" % (instance_name,)
@@ -238,7 +242,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                         "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
                     )
 
-                uri = "http://%s:%s/_synapse/replication/%s/%s" % (
+                # Here the protocol is hard coded to be http by default or https in case the replication
+                # port is set to have tls true.
+                scheme = "https" if tls else "http"
+                uri = "%s://%s:%s/_synapse/replication/%s/%s" % (
+                    scheme,
                     host,
                     port,
                     cls.NAME,
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 3d63645726..7c4941c3d3 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from twisted.web.server import Request
 
 from synapse.http.server import HttpServer
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict
 
@@ -62,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.device_list_updater = hs.get_device_handler().device_list_updater
+        from synapse.handlers.device import DeviceHandler
+
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_list_updater = handler.device_list_updater
+
         self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
 
@@ -72,11 +78,77 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
 
     async def _handle_request(  # type: ignore[override]
         self, request: Request, user_id: str
-    ) -> Tuple[int, JsonDict]:
+    ) -> Tuple[int, Optional[JsonDict]]:
         user_devices = await self.device_list_updater.user_device_resync(user_id)
 
         return 200, user_devices
 
 
+class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
+    """Ask master to upload keys for the user and send them out over federation to
+    update other servers.
+
+    For now, only the master is permitted to handle key upload requests;
+    any worker can handle key query requests (since they're read-only).
+
+    Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on
+    the main process to accomplish this.
+
+    Defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload
+    Request format(borrowed and expanded from KeyUploadServlet):
+
+        POST /_synapse/replication/upload_keys_for_user
+
+    {
+        "user_id": "<user_id>",
+        "device_id": "<device_id>",
+        "keys": {
+            ....this part can be found in KeyUploadServlet in rest/client/keys.py....
+        }
+    }
+
+    Response is equivalent to ` /_matrix/client/v3/keys/upload` found in KeyUploadServlet
+
+    """
+
+    NAME = "upload_keys_for_user"
+    PATH_ARGS = ()
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+        self.store = hs.get_datastores().main
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(  # type: ignore[override]
+        user_id: str, device_id: str, keys: JsonDict
+    ) -> JsonDict:
+
+        return {
+            "user_id": user_id,
+            "device_id": device_id,
+            "keys": keys,
+        }
+
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request
+    ) -> Tuple[int, JsonDict]:
+        content = parse_json_object_from_request(request)
+
+        user_id = content["user_id"]
+        device_id = content["device_id"]
+        keys = content["keys"]
+
+        results = await self.e2e_keys_handler.upload_keys_for_user(
+            user_id, device_id, keys
+        )
+
+        return 200, results
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
+    ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 61abb529c8..976c283360 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -39,6 +39,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
         self.store = hs.get_datastores().main
         self.registration_handler = hs.get_registration_handler()
 
+        # Default value if the worker that sent the replication request did not include
+        # an 'approved' property.
+        if (
+            hs.config.experimental.msc3866.enabled
+            and hs.config.experimental.msc3866.require_approval_for_new_accounts
+        ):
+            self._approval_default = False
+        else:
+            self._approval_default = True
+
     @staticmethod
     async def _serialize_payload(  # type: ignore[override]
         user_id: str,
@@ -92,6 +102,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
 
         await self.registration_handler.check_registration_ratelimit(content["address"])
 
+        # Always default admin users to approved (since it means they were created by
+        # an admin).
+        approved_default = self._approval_default
+        if content["admin"]:
+            approved_default = True
+
         await self.registration_handler.register_with_store(
             user_id=user_id,
             password_hash=content["password_hash"],
@@ -103,7 +119,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             user_type=content["user_type"],
             address=content["address"],
             shadow_banned=content["shadow_banned"],
-            approved=content["approved"],
+            approved=content.get("approved", approved_default),
         )
 
         return 200, {}