diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bd0cfa7f32..01206950a9 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore):
)
continue
+ # If we're using SQLite, then boolean values are integers. This is
+ # troublesome since some code using the return value of this method might
+ # expect it to be a boolean, or will expose it to clients (in responses).
+ r["enabled"] = bool(r["enabled"])
+
yield PusherConfig(**r)
async def get_pushers_by_app_id_and_pushkey(
@@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore):
return await self.get_pushers_by({"user_name": user_id})
async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
- ret = await self.db_pool.simple_select_list(
- "pushers",
- keyvalues,
- [
- "id",
- "user_name",
- "access_token",
- "profile_tag",
- "kind",
- "app_id",
- "app_display_name",
- "device_display_name",
- "pushkey",
- "ts",
- "lang",
- "data",
- "last_stream_ordering",
- "last_success",
- "failing_since",
- ],
+ """Retrieve pushers that match the given criteria.
+
+ Args:
+ keyvalues: A {column: value} dictionary.
+
+ Returns:
+ The pushers for which the given columns have the given values.
+ """
+
+ def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ # We could technically use simple_select_list here, but we need to call
+ # COALESCE on the 'enabled' column. While it is technically possible to give
+ # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
+ # feels a bit hacky, so it's probably better to just inline the query.
+ sql = """
+ SELECT
+ id, user_name, access_token, profile_tag, kind, app_id,
+ app_display_name, device_display_name, pushkey, ts, lang, data,
+ last_stream_ordering, last_success, failing_since,
+ COALESCE(enabled, TRUE) AS enabled, device_id
+ FROM pushers
+ """
+
+ sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
+
+ txn.execute(sql, list(keyvalues.values()))
+
+ return self.db_pool.cursor_to_dict(txn)
+
+ ret = await self.db_pool.runInteraction(
desc="get_pushers_by",
+ func=get_pushers_by_txn,
)
+
return self._decode_pushers_rows(ret)
- async def get_all_pushers(self) -> Iterator[PusherConfig]:
- def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
- txn.execute("SELECT * FROM pushers")
+ async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
+ def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
+ txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
+ return await self.db_pool.runInteraction(
+ "get_enabled_pushers", get_enabled_pushers_txn
+ )
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -458,7 +477,74 @@ class PusherWorkerStore(SQLBaseStore):
return number_deleted
-class PusherStore(PusherWorkerStore):
+class PusherBackgroundUpdatesStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_update_handler(
+ "set_device_id_for_pushers", self._set_device_id_for_pushers
+ )
+
+ async def _set_device_id_for_pushers(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Background update to populate the device_id column of the pushers table."""
+ last_pusher_id = progress.get("pusher_id", 0)
+
+ def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int:
+ txn.execute(
+ """
+ SELECT p.id, at.device_id
+ FROM pushers AS p
+ INNER JOIN access_tokens AS at
+ ON p.access_token = at.id
+ WHERE
+ p.access_token IS NOT NULL
+ AND at.device_id IS NOT NULL
+ AND p.id > ?
+ ORDER BY p.id
+ LIMIT ?
+ """,
+ (last_pusher_id, batch_size),
+ )
+
+ rows = self.db_pool.cursor_to_dict(txn)
+ if len(rows) == 0:
+ return 0
+
+ self.db_pool.simple_update_many_txn(
+ txn=txn,
+ table="pushers",
+ key_names=("id",),
+ key_values=[(row["id"],) for row in rows],
+ value_names=("device_id",),
+ value_values=[(row["device_id"],) for row in rows],
+ )
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]}
+ )
+
+ return len(rows)
+
+ nb_processed = await self.db_pool.runInteraction(
+ "set_device_id_for_pushers", set_device_id_for_pushers_txn
+ )
+
+ if nb_processed < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "set_device_id_for_pushers"
+ )
+
+ return nb_processed
+
+
+class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@@ -476,6 +562,8 @@ class PusherStore(PusherWorkerStore):
data: Optional[JsonDict],
last_stream_ordering: int,
profile_tag: str = "",
+ enabled: bool = True,
+ device_id: Optional[str] = None,
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
@@ -494,6 +582,8 @@ class PusherStore(PusherWorkerStore):
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
+ "enabled": enabled,
+ "device_id": device_id,
},
desc="add_pusher",
lock=False,
|