diff options
author | H. Shay <hillerys@element.io> | 2023-05-10 09:04:56 -0700 |
---|---|---|
committer | H. Shay <hillerys@element.io> | 2023-05-10 09:04:56 -0700 |
commit | e156b84c3f7156b91b1c59297463aa58c25f3a93 (patch) | |
tree | 1ae8e6a74715dbe7fd5037946efd79d443005e4e | |
parent | move ExperimentalFeature definition to avoid circular import (diff) | |
download | synapse-e156b84c3f7156b91b1c59297463aa58c25f3a93.tar.xz |
consolidate logic checking config and db to one place
-rw-r--r-- | synapse/handlers/presence.py | 15 | ||||
-rw-r--r-- | synapse/rest/client/keys.py | 6 | ||||
-rw-r--r-- | synapse/rest/client/pusher.py | 15 | ||||
-rw-r--r-- | synapse/storage/databases/main/experimental_features.py | 36 |
4 files changed, 35 insertions, 37 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index f681beb92e..9516d3fbf8 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -63,6 +63,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.storage.databases.main import DataStore +from synapse.storage.databases.main.experimental_features import ExperimentalFeature from synapse.streams import EventSource from synapse.types import ( JsonDict, @@ -605,11 +606,8 @@ class WorkerPresenceHandler(BasePresenceHandler): PresenceState.BUSY, ) - busy_presence_enabled = ( - await self.hs.get_datastores().main.get_feature_enabled( - target_user.to_string(), "msc3026" - ) - or self.hs.config.experimental.msc3026_enabled + busy_presence_enabled = await self.hs.get_datastores().main.get_feature_enabled( + target_user.to_string(), ExperimentalFeature.MSC3026 ) if presence not in valid_presence or ( @@ -1241,11 +1239,8 @@ class PresenceHandler(BasePresenceHandler): PresenceState.BUSY, ) - busy_presence_enabled = ( - await self.hs.get_datastores().main.get_feature_enabled( - target_user.to_string(), "msc3026" - ) - or self.hs.config.experimental.msc3026_enabled + busy_presence_enabled = await self.hs.get_datastores().main.get_feature_enabled( + target_user.to_string(), ExperimentalFeature.MSC3026 ) if presence not in valid_presence or ( diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 5d8b9239a4..bc120d262a 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -31,6 +31,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.rest.client._base import client_patterns, interactive_auth_handler +from synapse.storage.databases.main.experimental_features import ExperimentalFeature from synapse.types import JsonDict, StreamToken from synapse.util.cancellation import cancellable @@ -375,9 +376,8 @@ class SigningKeyUploadServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) - msc3967_enabled = ( - await self.hs.get_datastores().main.get_feature_enabled(user_id, "msc3967") - or self.hs.config.experimental.msc2654_enabled + msc3967_enabled = await self.hs.get_datastores().main.get_feature_enabled( + user_id, ExperimentalFeature.MSC3967 ) if msc3967_enabled: diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 40c4b6adc3..5fcb19beb9 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -27,6 +27,7 @@ from synapse.http.site import SynapseRequest from synapse.push import PusherConfigException from synapse.rest.client._base import client_patterns from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource +from synapse.storage.databases.main.experimental_features import ExperimentalFeature from synapse.types import JsonDict if TYPE_CHECKING: @@ -53,11 +54,8 @@ class PushersRestServlet(RestServlet): pusher_dicts = [p.as_dict() for p in pushers] - msc3881_enabled = ( - await self.hs.get_datastores().main.get_feature_enabled( - user.to_string(), "msc3881" - ) - or self.hs.config.experimental.msc3881_enabled + msc3881_enabled = await self.hs.get_datastores().main.get_feature_enabled( + user.to_string(), ExperimentalFeature.MSC3881 ) for pusher in pusher_dicts: @@ -118,11 +116,8 @@ class PushersSetRestServlet(RestServlet): append = content["append"] enabled = True - msc3881_enabled = ( - await self.hs.get_datastores().main.get_feature_enabled( - user.to_string(), "msc3881" - ) - or self.hs.config.experimental.msc3881_enabled + msc3881_enabled = await self.hs.get_datastores().main.get_feature_enabled( + user.to_string(), ExperimentalFeature.MSC3881 ) if msc3881_enabled and "org.matrix.msc3881.enabled" in content: diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py index 197cdb21ce..7a0a2caa0e 100644 --- a/synapse/storage/databases/main/experimental_features.py +++ b/synapse/storage/databases/main/experimental_features.py @@ -34,6 +34,7 @@ class ExperimentalFeature(str, Enum): MSC3881 = "msc3881" MSC3967 = "msc3967" + class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): def __init__( self, @@ -84,19 +85,32 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): await self.invalidate_cache_and_stream("list_enabled_features", (user,)) - async def get_feature_enabled(self, user_id: str, feature: str) -> bool: + async def get_feature_enabled( + self, user_id: str, feature: "ExperimentalFeature" + ) -> bool: """ Checks to see if a given feature is enabled for the user + Args: - user: - the user to be queried on - feature: - the feature in question + user_id: the user to be queried on + feature: the feature in question Returns: True if the feature is enabled, False if it is not or if the feature was not found. """ + # check first if feature is enabled in the config + if feature == ExperimentalFeature.MSC3026: + globally_enabled = self.hs.config.experimental.msc3026_enabled + elif feature == ExperimentalFeature.MSC3881: + globally_enabled = self.hs.config.experimental.msc3881_enabled + else: + globally_enabled = self.hs.config.experimental.msc3967_enabled + + if globally_enabled: + return globally_enabled + + # if it's not enabled globally, check if it is enabled per-user res = await self.db_pool.simple_select_one( "per_user_experimental_features", {"user_id": user_id, "feature": feature}, @@ -104,13 +118,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): allow_none=True, ) - if not res: - res = {"enabled": False} - - # Deal with Sqlite boolean return values - if res["enabled"] == 0: - res["enabled"] = False - if res["enabled"] == 1: - res["enabled"] = True + # None and false are treated the same + db_enabled = bool(res) - return res["enabled"] + return db_enabled |