summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17392.misc1
-rw-r--r--docs/admin_api/experimental_features.md17
-rw-r--r--synapse/rest/admin/experimental_features.py11
-rw-r--r--synapse/rest/client/pusher.py29
-rw-r--r--synapse/rest/client/versions.py20
-rw-r--r--synapse/storage/databases/main/experimental_features.py64
-rw-r--r--tests/push/test_http.py82
-rw-r--r--tests/rest/admin/test_admin.py14
8 files changed, 189 insertions, 49 deletions
diff --git a/changelog.d/17392.misc b/changelog.d/17392.misc
new file mode 100644
index 0000000000..76e3976e28
--- /dev/null
+++ b/changelog.d/17392.misc
@@ -0,0 +1 @@
+Finish up work to allow per-user feature flags.
diff --git a/docs/admin_api/experimental_features.md b/docs/admin_api/experimental_features.md
index 07b630915d..250cfc13a3 100644
--- a/docs/admin_api/experimental_features.md
+++ b/docs/admin_api/experimental_features.md
@@ -1,21 +1,16 @@
 # Experimental Features API
 
 This API allows a server administrator to enable or disable some experimental features on a per-user
-basis. The currently supported features are: 
-- [MSC3026](https://github.com/matrix-org/matrix-spec-proposals/pull/3026): busy 
-presence state enabled
-- [MSC3881](https://github.com/matrix-org/matrix-spec-proposals/pull/3881): enable remotely toggling push notifications 
-for another client 
-- [MSC3967](https://github.com/matrix-org/matrix-spec-proposals/pull/3967): do not require
-UIA when first uploading cross-signing keys. 
-
+basis. The currently supported features are:
+- [MSC3881](https://github.com/matrix-org/matrix-spec-proposals/pull/3881): enable remotely toggling push notifications
+for another client
 
 To use it, you will need to authenticate by providing an `access_token`
 for a server admin: see [Admin API](../usage/administration/admin_api/).
 
 ## Enabling/Disabling Features
 
-This API allows a server administrator to enable experimental features for a given user. The request must 
+This API allows a server administrator to enable experimental features for a given user. The request must
 provide a body containing the user id and listing the features to enable/disable in the following format:
 ```json
 {
@@ -35,7 +30,7 @@ PUT /_synapse/admin/v1/experimental_features/<user_id>
 ```
 
 ## Listing Enabled Features
- 
+
 To list which features are enabled/disabled for a given user send a request to the following API:
 
 ```
@@ -52,4 +47,4 @@ user like so:
       "msc3967": false
    }
 }
-```
\ No newline at end of file
+```
diff --git a/synapse/rest/admin/experimental_features.py b/synapse/rest/admin/experimental_features.py
index c5a00c490c..c1559c92f7 100644
--- a/synapse/rest/admin/experimental_features.py
+++ b/synapse/rest/admin/experimental_features.py
@@ -31,7 +31,9 @@ from synapse.rest.admin import admin_patterns, assert_requester_is_admin
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
-    from synapse.server import HomeServer
+    from typing_extensions import assert_never
+
+    from synapse.server import HomeServer, HomeServerConfig
 
 
 class ExperimentalFeature(str, Enum):
@@ -39,9 +41,14 @@ class ExperimentalFeature(str, Enum):
     Currently supported per-user features
     """
 
-    MSC3026 = "msc3026"
     MSC3881 = "msc3881"
 
+    def is_globally_enabled(self, config: "HomeServerConfig") -> bool:
+        if self is ExperimentalFeature.MSC3881:
+            return config.experimental.msc3881_enabled
+
+        assert_never(self)
+
 
 class ExperimentalFeaturesRestServlet(RestServlet):
     """
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 9957d2fcbe..a455f95a26 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -32,6 +32,7 @@ from synapse.http.servlet import (
 )
 from synapse.http.site import SynapseRequest
 from synapse.push import PusherConfigException
+from synapse.rest.admin.experimental_features import ExperimentalFeature
 from synapse.rest.client._base import client_patterns
 from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
 from synapse.types import JsonDict
@@ -49,20 +50,22 @@ class PushersRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
+        self._store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        user = requester.user
+        user_id = requester.user.to_string()
 
-        pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(
-            user.to_string()
+        msc3881_enabled = await self._store.is_feature_enabled(
+            user_id, ExperimentalFeature.MSC3881
         )
 
+        pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(user_id)
+
         pusher_dicts = [p.as_dict() for p in pushers]
 
         for pusher in pusher_dicts:
-            if self._msc3881_enabled:
+            if msc3881_enabled:
                 pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
                 pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
             del pusher["enabled"]
@@ -80,11 +83,15 @@ class PushersSetRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
         self.pusher_pool = self.hs.get_pusherpool()
-        self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
+        self._store = hs.get_datastores().main
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
-        user = requester.user
+        user_id = requester.user.to_string()
+
+        msc3881_enabled = await self._store.is_feature_enabled(
+            user_id, ExperimentalFeature.MSC3881
+        )
 
         content = parse_json_object_from_request(request)
 
@@ -95,7 +102,7 @@ class PushersSetRestServlet(RestServlet):
             and content["kind"] is None
         ):
             await self.pusher_pool.remove_pusher(
-                content["app_id"], content["pushkey"], user_id=user.to_string()
+                content["app_id"], content["pushkey"], user_id=user_id
             )
             return 200, {}
 
@@ -120,19 +127,19 @@ class PushersSetRestServlet(RestServlet):
             append = content["append"]
 
         enabled = True
-        if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
+        if msc3881_enabled and "org.matrix.msc3881.enabled" in content:
             enabled = content["org.matrix.msc3881.enabled"]
 
         if not append:
             await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
                 app_id=content["app_id"],
                 pushkey=content["pushkey"],
-                not_user_id=user.to_string(),
+                not_user_id=user_id,
             )
 
         try:
             await self.pusher_pool.add_or_update_pusher(
-                user_id=user.to_string(),
+                user_id=user_id,
                 kind=content["kind"],
                 app_id=content["app_id"],
                 app_display_name=content["app_display_name"],
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index f428158139..e01e5f542a 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -25,11 +25,11 @@ import logging
 import re
 from typing import TYPE_CHECKING, Tuple
 
-from twisted.web.server import Request
-
 from synapse.api.constants import RoomCreationPreset
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin.experimental_features import ExperimentalFeature
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -45,6 +45,8 @@ class VersionsRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.config = hs.config
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
 
         # Calculate these once since they shouldn't change after start-up.
         self.e2ee_forced_public = (
@@ -60,7 +62,17 @@ class VersionsRestServlet(RestServlet):
             in self.config.room.encryption_enabled_by_default_for_room_presets
         )
 
-    def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        msc3881_enabled = self.config.experimental.msc3881_enabled
+
+        if self.auth.has_access_token(request):
+            requester = await self.auth.get_user_by_req(request)
+            user_id = requester.user.to_string()
+
+            msc3881_enabled = await self.store.is_feature_enabled(
+                user_id, ExperimentalFeature.MSC3881
+            )
+
         return (
             200,
             {
@@ -124,7 +136,7 @@ class VersionsRestServlet(RestServlet):
                     # TODO: this is no longer needed once unstable MSC3882 does not need to be supported:
                     "org.matrix.msc3882": self.config.auth.login_via_existing_enabled,
                     # Adds support for remotely enabling/disabling pushers, as per MSC3881
-                    "org.matrix.msc3881": self.config.experimental.msc3881_enabled,
+                    "org.matrix.msc3881": msc3881_enabled,
                     # Adds support for filtering /messages by event relation.
                     "org.matrix.msc3874": self.config.experimental.msc3874_enabled,
                     # Adds support for simple HTTP rendezvous as per MSC3886
diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py
index fbb98d8f63..d980c57fa8 100644
--- a/synapse/storage/databases/main/experimental_features.py
+++ b/synapse/storage/databases/main/experimental_features.py
@@ -21,7 +21,11 @@
 
 from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
 
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.util.caches.descriptors import cached
 
@@ -73,12 +77,54 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
             features:
                 pairs of features and True/False for whether the feature should be enabled
         """
-        for feature, enabled in features.items():
-            await self.db_pool.simple_upsert(
-                table="per_user_experimental_features",
-                keyvalues={"feature": feature, "user_id": user},
-                values={"enabled": enabled},
-                insertion_values={"user_id": user, "feature": feature},
-            )
 
-            await self.invalidate_cache_and_stream("list_enabled_features", (user,))
+        def set_features_for_user_txn(txn: LoggingTransaction) -> None:
+            for feature, enabled in features.items():
+                self.db_pool.simple_upsert_txn(
+                    txn,
+                    table="per_user_experimental_features",
+                    keyvalues={"feature": feature, "user_id": user},
+                    values={"enabled": enabled},
+                    insertion_values={"user_id": user, "feature": feature},
+                )
+
+                self._invalidate_cache_and_stream(
+                    txn, self.is_feature_enabled, (user, feature)
+                )
+
+            self._invalidate_cache_and_stream(txn, self.list_enabled_features, (user,))
+
+        return await self.db_pool.runInteraction(
+            "set_features_for_user", set_features_for_user_txn
+        )
+
+    @cached()
+    async def is_feature_enabled(
+        self, user_id: str, feature: "ExperimentalFeature"
+    ) -> bool:
+        """
+        Checks to see if a given feature is enabled for the user
+        Args:
+            user_id: the user to be queried on
+            feature: the feature in question
+        Returns:
+                True if the feature is enabled, False if it is not or if the feature was
+                not found.
+        """
+
+        if feature.is_globally_enabled(self.hs.config):
+            return True
+
+        # if it's not enabled globally, check if it is enabled per-user
+        res = await self.db_pool.simple_select_one_onecol(
+            table="per_user_experimental_features",
+            keyvalues={"user_id": user_id, "feature": feature},
+            retcol="enabled",
+            allow_none=True,
+            desc="get_feature_enabled",
+        )
+
+        # None and false are treated the same
+        db_enabled = bool(res)
+
+        return db_enabled
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index dce00d8b7f..bcca472617 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -26,7 +26,8 @@ from twisted.test.proto_helpers import MemoryReactor
 import synapse.rest.admin
 from synapse.logging.context import make_deferred_yieldable
 from synapse.push import PusherConfig, PusherConfigException
-from synapse.rest.client import login, push_rule, pusher, receipts, room
+from synapse.rest.admin.experimental_features import ExperimentalFeature
+from synapse.rest.client import login, push_rule, pusher, receipts, room, versions
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -42,6 +43,7 @@ class HTTPPusherTests(HomeserverTestCase):
         receipts.register_servlets,
         push_rule.register_servlets,
         pusher.register_servlets,
+        versions.register_servlets,
     ]
     user_id = True
     hijack_auth = False
@@ -969,6 +971,84 @@ class HTTPPusherTests(HomeserverTestCase):
             lookup_result.device_id,
         )
 
+    def test_device_id_feature_flag(self) -> None:
+        """Tests that a pusher created with a given device ID shows that device ID in
+        GET /pushers requests when feature is enabled for the user
+        """
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # We create the pusher with an HTTP request rather than with
+        # _make_user_with_pusher so that we can test the device ID is correctly set when
+        # creating a pusher via an API call.
+        self.make_request(
+            method="POST",
+            path="/pushers/set",
+            content={
+                "kind": "http",
+                "app_id": "m.http",
+                "app_display_name": "HTTP Push Notifications",
+                "device_display_name": "pushy push",
+                "pushkey": "a@example.com",
+                "lang": "en",
+                "data": {"url": "http://example.com/_matrix/push/v1/notify"},
+            },
+            access_token=access_token,
+        )
+
+        # Look up the user info for the access token so we can compare the device ID.
+        store = self.hs.get_datastores().main
+        lookup_result = self.get_success(store.get_user_by_access_token(access_token))
+        assert lookup_result is not None
+
+        # Check field is not there before we enable the feature flag
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+        self.assertNotIn(
+            "org.matrix.msc3881.device_id", channel.json_body["pushers"][0]
+        )
+
+        self.get_success(
+            store.set_features_for_user(user_id, {ExperimentalFeature.MSC3881: True})
+        )
+
+        # Get the user's devices and check it has the correct device ID.
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+        self.assertEqual(
+            channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
+            lookup_result.device_id,
+        )
+
+    def test_msc3881_client_versions_flag(self) -> None:
+        """Tests that MSC3881 only appears in /versions if user has it enabled."""
+
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Check feature is disabled in /versions
+        channel = self.make_request(
+            "GET", "/_matrix/client/versions", access_token=access_token
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(channel.json_body["unstable_features"]["org.matrix.msc3881"])
+
+        # Enable feature for user
+        self.get_success(
+            self.hs.get_datastores().main.set_features_for_user(
+                user_id, {ExperimentalFeature.MSC3881: True}
+            )
+        )
+
+        # Check feature is now enabled in /versions for user
+        channel = self.make_request(
+            "GET", "/_matrix/client/versions", access_token=access_token
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(channel.json_body["unstable_features"]["org.matrix.msc3881"])
+
     @override_config({"push": {"jitter_delay": "10s"}})
     def test_jitter(self) -> None:
         """Tests that enabling jitter actually delays sending push."""
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 5f6f7213b3..6351326fff 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -384,7 +384,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
             "PUT",
             url,
             content={
-                "features": {"msc3026": True, "msc3881": True},
+                "features": {"msc3881": True},
             },
             access_token=self.admin_user_tok,
         )
@@ -401,10 +401,6 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(
             True,
-            channel.json_body["features"]["msc3026"],
-        )
-        self.assertEqual(
-            True,
             channel.json_body["features"]["msc3881"],
         )
 
@@ -413,7 +409,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "PUT",
             url,
-            content={"features": {"msc3026": False}},
+            content={"features": {"msc3881": False}},
             access_token=self.admin_user_tok,
         )
         self.assertEqual(channel.code, 200)
@@ -429,10 +425,6 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(
             False,
-            channel.json_body["features"]["msc3026"],
-        )
-        self.assertEqual(
-            True,
             channel.json_body["features"]["msc3881"],
         )
 
@@ -441,7 +433,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "PUT",
             url,
-            content={"features": {"msc3026": False}},
+            content={"features": {"msc3881": False}},
             access_token=self.admin_user_tok,
         )
         self.assertEqual(channel.code, 200)