summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth/__init__.py18
-rw-r--r--synapse/api/auth/internal.py29
-rw-r--r--synapse/api/auth/msc3861_delegated.py28
-rw-r--r--synapse/rest/admin/experimental_features.py3
-rw-r--r--synapse/rest/client/sync.py15
5 files changed, 87 insertions, 6 deletions
diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py
index 234dcf1ca4..d5241afe73 100644
--- a/synapse/api/auth/__init__.py
+++ b/synapse/api/auth/__init__.py
@@ -18,7 +18,7 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from typing_extensions import Protocol
 
@@ -28,6 +28,9 @@ from synapse.appservice import ApplicationService
 from synapse.http.site import SynapseRequest
 from synapse.types import Requester
 
+if TYPE_CHECKING:
+    from synapse.rest.admin.experimental_features import ExperimentalFeature
+
 # guests always get this device id.
 GUEST_DEVICE_ID = "guest_device"
 
@@ -87,6 +90,19 @@ class Auth(Protocol):
             AuthError if access is denied for the user in the access token
         """
 
+    async def get_user_by_req_experimental_feature(
+        self,
+        request: SynapseRequest,
+        feature: "ExperimentalFeature",
+        allow_guest: bool = False,
+        allow_expired: bool = False,
+        allow_locked: bool = False,
+    ) -> Requester:
+        """Like `get_user_by_req`, except also checks if the user has access to
+        the experimental feature. If they don't returns a 404 unrecognized
+        request.
+        """
+
     async def validate_appservice_can_control_user_id(
         self, app_service: ApplicationService, user_id: str
     ) -> None:
diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py
index 2878f3e6e9..9fd4db68e1 100644
--- a/synapse/api/auth/internal.py
+++ b/synapse/api/auth/internal.py
@@ -28,6 +28,7 @@ from synapse.api.errors import (
     Codes,
     InvalidClientTokenError,
     MissingClientTokenError,
+    UnrecognizedRequestError,
 )
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import active_span, force_tracing, start_active_span
@@ -38,8 +39,10 @@ from . import GUEST_DEVICE_ID
 from .base import BaseAuth
 
 if TYPE_CHECKING:
+    from synapse.rest.admin.experimental_features import ExperimentalFeature
     from synapse.server import HomeServer
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -106,6 +109,32 @@ class InternalAuth(BaseAuth):
                     parent_span.set_tag("appservice_id", requester.app_service.id)
             return requester
 
+    async def get_user_by_req_experimental_feature(
+        self,
+        request: SynapseRequest,
+        feature: "ExperimentalFeature",
+        allow_guest: bool = False,
+        allow_expired: bool = False,
+        allow_locked: bool = False,
+    ) -> Requester:
+        try:
+            requester = await self.get_user_by_req(
+                request,
+                allow_guest=allow_guest,
+                allow_expired=allow_expired,
+                allow_locked=allow_locked,
+            )
+            if await self.store.is_feature_enabled(requester.user.to_string(), feature):
+                return requester
+
+            raise UnrecognizedRequestError(code=404)
+        except (AuthError, InvalidClientTokenError):
+            if feature.is_globally_enabled(self.hs.config):
+                # If its globally enabled then return the auth error
+                raise
+
+            raise UnrecognizedRequestError(code=404)
+
     @cancellable
     async def _wrapped_get_user_by_req(
         self,
diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index 3146e1577c..f61b39ded7 100644
--- a/synapse/api/auth/msc3861_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -40,6 +40,7 @@ from synapse.api.errors import (
     OAuthInsufficientScopeError,
     StoreError,
     SynapseError,
+    UnrecognizedRequestError,
 )
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -48,6 +49,7 @@ from synapse.util import json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
 
 if TYPE_CHECKING:
+    from synapse.rest.admin.experimental_features import ExperimentalFeature
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
@@ -245,6 +247,32 @@ class MSC3861DelegatedAuth(BaseAuth):
 
         return requester
 
+    async def get_user_by_req_experimental_feature(
+        self,
+        request: SynapseRequest,
+        feature: "ExperimentalFeature",
+        allow_guest: bool = False,
+        allow_expired: bool = False,
+        allow_locked: bool = False,
+    ) -> Requester:
+        try:
+            requester = await self.get_user_by_req(
+                request,
+                allow_guest=allow_guest,
+                allow_expired=allow_expired,
+                allow_locked=allow_locked,
+            )
+            if await self.store.is_feature_enabled(requester.user.to_string(), feature):
+                return requester
+
+            raise UnrecognizedRequestError(code=404)
+        except (AuthError, InvalidClientTokenError):
+            if feature.is_globally_enabled(self.hs.config):
+                # If its globally enabled then return the auth error
+                raise
+
+            raise UnrecognizedRequestError(code=404)
+
     async def get_user_by_access_token(
         self,
         token: str,
diff --git a/synapse/rest/admin/experimental_features.py b/synapse/rest/admin/experimental_features.py
index c1559c92f7..d7913896d9 100644
--- a/synapse/rest/admin/experimental_features.py
+++ b/synapse/rest/admin/experimental_features.py
@@ -42,10 +42,13 @@ class ExperimentalFeature(str, Enum):
     """
 
     MSC3881 = "msc3881"
+    MSC3575 = "msc3575"
 
     def is_globally_enabled(self, config: "HomeServerConfig") -> bool:
         if self is ExperimentalFeature.MSC3881:
             return config.experimental.msc3881_enabled
+        if self is ExperimentalFeature.MSC3575:
+            return config.experimental.msc3575_enabled
 
         assert_never(self)
 
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e52e771538..2a22bc14ec 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -53,6 +53,7 @@ from synapse.http.servlet import (
 )
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import trace_with_opname
+from synapse.rest.admin.experimental_features import ExperimentalFeature
 from synapse.types import JsonDict, Requester, StreamToken
 from synapse.types.rest.client import SlidingSyncBody
 from synapse.util import json_decoder
@@ -673,7 +674,9 @@ class SlidingSyncE2eeRestServlet(RestServlet):
         )
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        requester = await self.auth.get_user_by_req_experimental_feature(
+            request, allow_guest=True, feature=ExperimentalFeature.MSC3575
+        )
         user = requester.user
         device_id = requester.device_id
 
@@ -873,7 +876,10 @@ class SlidingSyncRestServlet(RestServlet):
         self.event_serializer = hs.get_event_client_serializer()
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        requester = await self.auth.get_user_by_req_experimental_feature(
+            request, allow_guest=True, feature=ExperimentalFeature.MSC3575
+        )
+
         user = requester.user
         device_id = requester.device_id
 
@@ -1051,6 +1057,5 @@ class SlidingSyncRestServlet(RestServlet):
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SyncRestServlet(hs).register(http_server)
 
-    if hs.config.experimental.msc3575_enabled:
-        SlidingSyncRestServlet(hs).register(http_server)
-        SlidingSyncE2eeRestServlet(hs).register(http_server)
+    SlidingSyncRestServlet(hs).register(http_server)
+    SlidingSyncE2eeRestServlet(hs).register(http_server)