summary refs log tree commit diff
diff options
context:
space:
mode:
authorH. Shay <hillerys@element.io>2023-05-10 09:04:56 -0700
committerH. Shay <hillerys@element.io>2023-05-10 09:04:56 -0700
commite156b84c3f7156b91b1c59297463aa58c25f3a93 (patch)
tree1ae8e6a74715dbe7fd5037946efd79d443005e4e
parentmove ExperimentalFeature definition to avoid circular import (diff)
downloadsynapse-e156b84c3f7156b91b1c59297463aa58c25f3a93.tar.xz
consolidate logic checking config and db to one place
-rw-r--r--synapse/handlers/presence.py15
-rw-r--r--synapse/rest/client/keys.py6
-rw-r--r--synapse/rest/client/pusher.py15
-rw-r--r--synapse/storage/databases/main/experimental_features.py36
4 files changed, 35 insertions, 37 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index f681beb92e..9516d3fbf8 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -63,6 +63,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
 from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.storage.databases.main import DataStore
+from synapse.storage.databases.main.experimental_features import ExperimentalFeature
 from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
@@ -605,11 +606,8 @@ class WorkerPresenceHandler(BasePresenceHandler):
             PresenceState.BUSY,
         )
 
-        busy_presence_enabled = (
-            await self.hs.get_datastores().main.get_feature_enabled(
-                target_user.to_string(), "msc3026"
-            )
-            or self.hs.config.experimental.msc3026_enabled
+        busy_presence_enabled = await self.hs.get_datastores().main.get_feature_enabled(
+            target_user.to_string(), ExperimentalFeature.MSC3026
         )
 
         if presence not in valid_presence or (
@@ -1241,11 +1239,8 @@ class PresenceHandler(BasePresenceHandler):
             PresenceState.BUSY,
         )
 
-        busy_presence_enabled = (
-            await self.hs.get_datastores().main.get_feature_enabled(
-                target_user.to_string(), "msc3026"
-            )
-            or self.hs.config.experimental.msc3026_enabled
+        busy_presence_enabled = await self.hs.get_datastores().main.get_feature_enabled(
+            target_user.to_string(), ExperimentalFeature.MSC3026
         )
 
         if presence not in valid_presence or (
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 5d8b9239a4..bc120d262a 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -31,6 +31,7 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import log_kv, set_tag
 from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
 from synapse.rest.client._base import client_patterns, interactive_auth_handler
+from synapse.storage.databases.main.experimental_features import ExperimentalFeature
 from synapse.types import JsonDict, StreamToken
 from synapse.util.cancellation import cancellable
 
@@ -375,9 +376,8 @@ class SigningKeyUploadServlet(RestServlet):
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
 
-        msc3967_enabled = (
-            await self.hs.get_datastores().main.get_feature_enabled(user_id, "msc3967")
-            or self.hs.config.experimental.msc2654_enabled
+        msc3967_enabled = await self.hs.get_datastores().main.get_feature_enabled(
+            user_id, ExperimentalFeature.MSC3967
         )
 
         if msc3967_enabled:
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 40c4b6adc3..5fcb19beb9 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -27,6 +27,7 @@ from synapse.http.site import SynapseRequest
 from synapse.push import PusherConfigException
 from synapse.rest.client._base import client_patterns
 from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
+from synapse.storage.databases.main.experimental_features import ExperimentalFeature
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -53,11 +54,8 @@ class PushersRestServlet(RestServlet):
 
         pusher_dicts = [p.as_dict() for p in pushers]
 
-        msc3881_enabled = (
-            await self.hs.get_datastores().main.get_feature_enabled(
-                user.to_string(), "msc3881"
-            )
-            or self.hs.config.experimental.msc3881_enabled
+        msc3881_enabled = await self.hs.get_datastores().main.get_feature_enabled(
+            user.to_string(), ExperimentalFeature.MSC3881
         )
 
         for pusher in pusher_dicts:
@@ -118,11 +116,8 @@ class PushersSetRestServlet(RestServlet):
             append = content["append"]
 
         enabled = True
-        msc3881_enabled = (
-            await self.hs.get_datastores().main.get_feature_enabled(
-                user.to_string(), "msc3881"
-            )
-            or self.hs.config.experimental.msc3881_enabled
+        msc3881_enabled = await self.hs.get_datastores().main.get_feature_enabled(
+            user.to_string(), ExperimentalFeature.MSC3881
         )
 
         if msc3881_enabled and "org.matrix.msc3881.enabled" in content:
diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py
index 197cdb21ce..7a0a2caa0e 100644
--- a/synapse/storage/databases/main/experimental_features.py
+++ b/synapse/storage/databases/main/experimental_features.py
@@ -34,6 +34,7 @@ class ExperimentalFeature(str, Enum):
     MSC3881 = "msc3881"
     MSC3967 = "msc3967"
 
+
 class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
@@ -84,19 +85,32 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
 
             await self.invalidate_cache_and_stream("list_enabled_features", (user,))
 
-    async def get_feature_enabled(self, user_id: str, feature: str) -> bool:
+    async def get_feature_enabled(
+        self, user_id: str, feature: "ExperimentalFeature"
+    ) -> bool:
         """
         Checks to see if a given feature is enabled for the user
+
         Args:
-            user:
-                the user to be queried on
-            feature:
-                the feature in question
+            user_id: the user to be queried on
+            feature: the feature in question
         Returns:
                 True if the feature is enabled, False if it is not or if the feature was
                 not found.
         """
 
+        # check first if feature is enabled in the config
+        if feature == ExperimentalFeature.MSC3026:
+            globally_enabled = self.hs.config.experimental.msc3026_enabled
+        elif feature == ExperimentalFeature.MSC3881:
+            globally_enabled = self.hs.config.experimental.msc3881_enabled
+        else:
+            globally_enabled = self.hs.config.experimental.msc3967_enabled
+
+        if globally_enabled:
+            return globally_enabled
+
+        # if it's not enabled globally, check if it is enabled per-user
         res = await self.db_pool.simple_select_one(
             "per_user_experimental_features",
             {"user_id": user_id, "feature": feature},
@@ -104,13 +118,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
             allow_none=True,
         )
 
-        if not res:
-            res = {"enabled": False}
-
-        # Deal with Sqlite boolean return values
-        if res["enabled"] == 0:
-            res["enabled"] = False
-        if res["enabled"] == 1:
-            res["enabled"] = True
+        # None and false are treated the same
+        db_enabled = bool(res)
 
-        return res["enabled"]
+        return db_enabled