diff options
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/appservice.py | 29 |
1 files changed, 22 insertions, 7 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 43bf0f649a..637a938bac 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -369,17 +369,25 @@ class ApplicationServiceTransactionWorkerStore( async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: + if type not in ("read_receipt", "presence"): + raise ValueError( + "Expected type to be a valid application stream id type, got %s" + % (type,) + ) + def get_type_stream_id_for_appservice_txn(txn): stream_id_type = "%s_stream_id" % type txn.execute( - "SELECT ? FROM application_services_state WHERE as_id=?", - (stream_id_type, service.id,), + # We do NOT want to escape `stream_id_type`. + "SELECT %s FROM application_services_state WHERE as_id=?" + % stream_id_type, + (service.id,), ) - last_txn_id = txn.fetchone() - if last_txn_id is None or last_txn_id[0] is None: # no row exists + last_stream_id = txn.fetchone() + if last_stream_id is None or last_stream_id[0] is None: # no row exists return 0 else: - return int(last_txn_id[0]) + return int(last_stream_id[0]) return await self.db_pool.runInteraction( "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn @@ -388,11 +396,18 @@ class ApplicationServiceTransactionWorkerStore( async def set_type_stream_id_for_appservice( self, service: ApplicationService, type: str, pos: int ) -> None: + if type not in ("read_receipt", "presence"): + raise ValueError( + "Expected type to be a valid application stream id type, got %s" + % (type,) + ) + def set_type_stream_id_for_appservice_txn(txn): stream_id_type = "%s_stream_id" % type txn.execute( - "UPDATE ? SET device_list_stream_id = ? WHERE as_id=?", - (stream_id_type, pos, service.id), + "UPDATE application_services_state SET %s = ? WHERE as_id=?" + % stream_id_type, + (pos, service.id), ) await self.db_pool.runInteraction( |