summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/pusher.py69
1 files changed, 45 insertions, 24 deletions
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bd0cfa7f32..ee55b8c4a9 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore):
                 )
                 continue
 
+            # If we're using SQLite, then boolean values are integers. This is
+            # troublesome since some code using the return value of this method might
+            # expect it to be a boolean, or will expose it to clients (in responses).
+            r["enabled"] = bool(r["enabled"])
+
             yield PusherConfig(**r)
 
     async def get_pushers_by_app_id_and_pushkey(
@@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore):
         return await self.get_pushers_by({"user_name": user_id})
 
     async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
-        ret = await self.db_pool.simple_select_list(
-            "pushers",
-            keyvalues,
-            [
-                "id",
-                "user_name",
-                "access_token",
-                "profile_tag",
-                "kind",
-                "app_id",
-                "app_display_name",
-                "device_display_name",
-                "pushkey",
-                "ts",
-                "lang",
-                "data",
-                "last_stream_ordering",
-                "last_success",
-                "failing_since",
-            ],
+        """Retrieve pushers that match the given criteria.
+
+        Args:
+            keyvalues: A {column: value} dictionary.
+
+        Returns:
+            The pushers for which the given columns have the given values.
+        """
+
+        def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+            # We could technically use simple_select_list here, but we need to call
+            # COALESCE on the 'enabled' column. While it is technically possible to give
+            # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
+            # feels a bit hacky, so it's probably better to just inline the query.
+            sql = """
+            SELECT
+                id, user_name, access_token, profile_tag, kind, app_id,
+                app_display_name, device_display_name, pushkey, ts, lang, data,
+                last_stream_ordering, last_success, failing_since,
+                COALESCE(enabled, TRUE) AS enabled
+            FROM pushers
+            """
+
+            sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
+
+            txn.execute(sql, list(keyvalues.values()))
+
+            return self.db_pool.cursor_to_dict(txn)
+
+        ret = await self.db_pool.runInteraction(
             desc="get_pushers_by",
+            func=get_pushers_by_txn,
         )
+
         return self._decode_pushers_rows(ret)
 
-    async def get_all_pushers(self) -> Iterator[PusherConfig]:
-        def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
-            txn.execute("SELECT * FROM pushers")
+    async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
+        def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
+            txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
             rows = self.db_pool.cursor_to_dict(txn)
 
             return self._decode_pushers_rows(rows)
 
-        return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
+        return await self.db_pool.runInteraction(
+            "get_enabled_pushers", get_enabled_pushers_txn
+        )
 
     async def get_all_updated_pushers_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore):
         data: Optional[JsonDict],
         last_stream_ordering: int,
         profile_tag: str = "",
+        enabled: bool = True,
     ) -> None:
         async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
@@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore):
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
+                    "enabled": enabled,
                 },
                 desc="add_pusher",
                 lock=False,