summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/rest/admin/experimental_features.py11
-rw-r--r--synapse/rest/client/pusher.py29
-rw-r--r--synapse/rest/client/versions.py20
-rw-r--r--synapse/storage/databases/main/experimental_features.py64
4 files changed, 98 insertions, 26 deletions
diff --git a/synapse/rest/admin/experimental_features.py b/synapse/rest/admin/experimental_features.py
index c5a00c490c..c1559c92f7 100644
--- a/synapse/rest/admin/experimental_features.py
+++ b/synapse/rest/admin/experimental_features.py
@@ -31,7 +31,9 @@ from synapse.rest.admin import admin_patterns, assert_requester_is_admin
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
-    from synapse.server import HomeServer
+    from typing_extensions import assert_never
+
+    from synapse.server import HomeServer, HomeServerConfig
 
 
 class ExperimentalFeature(str, Enum):
@@ -39,9 +41,14 @@ class ExperimentalFeature(str, Enum):
     Currently supported per-user features
     """
 
-    MSC3026 = "msc3026"
     MSC3881 = "msc3881"
 
+    def is_globally_enabled(self, config: "HomeServerConfig") -> bool:
+        if self is ExperimentalFeature.MSC3881:
+            return config.experimental.msc3881_enabled
+
+        assert_never(self)
+
 
 class ExperimentalFeaturesRestServlet(RestServlet):
     """
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 9957d2fcbe..a455f95a26 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -32,6 +32,7 @@ from synapse.http.servlet import (
 )
 from synapse.http.site import SynapseRequest
 from synapse.push import PusherConfigException
+from synapse.rest.admin.experimental_features import ExperimentalFeature
 from synapse.rest.client._base import client_patterns
 from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
 from synapse.types import JsonDict
@@ -49,20 +50,22 @@ class PushersRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
+        self._store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        user = requester.user
+        user_id = requester.user.to_string()
 
-        pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(
-            user.to_string()
+        msc3881_enabled = await self._store.is_feature_enabled(
+            user_id, ExperimentalFeature.MSC3881
         )
 
+        pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(user_id)
+
         pusher_dicts = [p.as_dict() for p in pushers]
 
         for pusher in pusher_dicts:
-            if self._msc3881_enabled:
+            if msc3881_enabled:
                 pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
                 pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
             del pusher["enabled"]
@@ -80,11 +83,15 @@ class PushersSetRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
         self.pusher_pool = self.hs.get_pusherpool()
-        self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
+        self._store = hs.get_datastores().main
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        user = requester.user
+        user_id = requester.user.to_string()
+
+        msc3881_enabled = await self._store.is_feature_enabled(
+            user_id, ExperimentalFeature.MSC3881
+        )
 
         content = parse_json_object_from_request(request)
 
@@ -95,7 +102,7 @@ class PushersSetRestServlet(RestServlet):
             and content["kind"] is None
         ):
             await self.pusher_pool.remove_pusher(
-                content["app_id"], content["pushkey"], user_id=user.to_string()
+                content["app_id"], content["pushkey"], user_id=user_id
             )
             return 200, {}
 
@@ -120,19 +127,19 @@ class PushersSetRestServlet(RestServlet):
             append = content["append"]
 
         enabled = True
-        if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
+        if msc3881_enabled and "org.matrix.msc3881.enabled" in content:
             enabled = content["org.matrix.msc3881.enabled"]
 
         if not append:
             await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
                 app_id=content["app_id"],
                 pushkey=content["pushkey"],
-                not_user_id=user.to_string(),
+                not_user_id=user_id,
             )
 
         try:
             await self.pusher_pool.add_or_update_pusher(
-                user_id=user.to_string(),
+                user_id=user_id,
                 kind=content["kind"],
                 app_id=content["app_id"],
                 app_display_name=content["app_display_name"],
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index f428158139..e01e5f542a 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -25,11 +25,11 @@ import logging
 import re
 from typing import TYPE_CHECKING, Tuple
 
-from twisted.web.server import Request
-
 from synapse.api.constants import RoomCreationPreset
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin.experimental_features import ExperimentalFeature
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -45,6 +45,8 @@ class VersionsRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.config = hs.config
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
 
         # Calculate these once since they shouldn't change after start-up.
         self.e2ee_forced_public = (
@@ -60,7 +62,17 @@ class VersionsRestServlet(RestServlet):
             in self.config.room.encryption_enabled_by_default_for_room_presets
         )
 
-    def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        msc3881_enabled = self.config.experimental.msc3881_enabled
+
+        if self.auth.has_access_token(request):
+            requester = await self.auth.get_user_by_req(request)
+            user_id = requester.user.to_string()
+
+            msc3881_enabled = await self.store.is_feature_enabled(
+                user_id, ExperimentalFeature.MSC3881
+            )
+
         return (
             200,
             {
@@ -124,7 +136,7 @@ class VersionsRestServlet(RestServlet):
                     # TODO: this is no longer needed once unstable MSC3882 does not need to be supported:
                     "org.matrix.msc3882": self.config.auth.login_via_existing_enabled,
                     # Adds support for remotely enabling/disabling pushers, as per MSC3881
-                    "org.matrix.msc3881": self.config.experimental.msc3881_enabled,
+                    "org.matrix.msc3881": msc3881_enabled,
                     # Adds support for filtering /messages by event relation.
                     "org.matrix.msc3874": self.config.experimental.msc3874_enabled,
                     # Adds support for simple HTTP rendezvous as per MSC3886
diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py
index fbb98d8f63..d980c57fa8 100644
--- a/synapse/storage/databases/main/experimental_features.py
+++ b/synapse/storage/databases/main/experimental_features.py
@@ -21,7 +21,11 @@
 
 from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
 
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.util.caches.descriptors import cached
 
@@ -73,12 +77,54 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
             features:
                 pairs of features and True/False for whether the feature should be enabled
         """
-        for feature, enabled in features.items():
-            await self.db_pool.simple_upsert(
-                table="per_user_experimental_features",
-                keyvalues={"feature": feature, "user_id": user},
-                values={"enabled": enabled},
-                insertion_values={"user_id": user, "feature": feature},
-            )
 
-            await self.invalidate_cache_and_stream("list_enabled_features", (user,))
+        def set_features_for_user_txn(txn: LoggingTransaction) -> None:
+            for feature, enabled in features.items():
+                self.db_pool.simple_upsert_txn(
+                    txn,
+                    table="per_user_experimental_features",
+                    keyvalues={"feature": feature, "user_id": user},
+                    values={"enabled": enabled},
+                    insertion_values={"user_id": user, "feature": feature},
+                )
+
+                self._invalidate_cache_and_stream(
+                    txn, self.is_feature_enabled, (user, feature)
+                )
+
+            self._invalidate_cache_and_stream(txn, self.list_enabled_features, (user,))
+
+        return await self.db_pool.runInteraction(
+            "set_features_for_user", set_features_for_user_txn
+        )
+
+    @cached()
+    async def is_feature_enabled(
+        self, user_id: str, feature: "ExperimentalFeature"
+    ) -> bool:
+        """
+        Checks to see if a given feature is enabled for the user
+        Args:
+            user_id: the user to be queried on
+            feature: the feature in question
+        Returns:
+                True if the feature is enabled, False if it is not or if the feature was
+                not found.
+        """
+
+        if feature.is_globally_enabled(self.hs.config):
+            return True
+
+        # if it's not enabled globally, check if it is enabled per-user
+        res = await self.db_pool.simple_select_one_onecol(
+            table="per_user_experimental_features",
+            keyvalues={"user_id": user_id, "feature": feature},
+            retcol="enabled",
+            allow_none=True,
+            desc="get_feature_enabled",
+        )
+
+        # None and false are treated the same
+        db_enabled = bool(res)
+
+        return db_enabled