diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 95567826f2..a9843f6e17 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -40,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
- AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
@@ -64,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
):
super().__init__(database, db_conn, hs)
- # `_can_write_to_account_data` indicates whether the current worker is allowed
- # to write account data. A value of `True` implies that `_account_data_id_gen`
- # is an `AbstractStreamIdGenerator` and not just a tracker.
- self._account_data_id_gen: AbstractStreamIdTracker
self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data
)
+ self._account_data_id_gen: AbstractStreamIdGenerator
+
if isinstance(database.engine, PostgresEngine):
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
@@ -237,6 +234,37 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
return None
+ async def get_latest_stream_id_for_global_account_data_by_type_for_user(
+ self, user_id: str, data_type: str
+ ) -> Optional[int]:
+ """
+ Returns:
+ The stream ID of the account data,
+ or None if there is no such account data.
+ """
+
+ def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
+ sql = """
+ SELECT stream_id FROM account_data
+ WHERE user_id = ? AND account_data_type = ?
+ ORDER BY stream_id DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (user_id, data_type))
+
+ row = txn.fetchone()
+ if row:
+ return row[0]
+ else:
+ return None
+
+ return await self.db_pool.runInteraction(
+ "get_latest_stream_id_for_global_account_data_by_type_for_user",
+ get_latest_stream_id_for_global_account_data_by_type_for_user_txn,
+ )
+
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
@@ -527,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
The maximum stream ID.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
content_json = json_encoder.encode(content)
@@ -554,7 +581,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def remove_account_data_for_room(
self, user_id: str, room_id: str, account_data_type: str
- ) -> Optional[int]:
+ ) -> int:
"""Delete the room account data for the user of a given type.
Args:
@@ -567,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
data to delete.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_room_txn(
txn: LoggingTransaction, next_id: int
@@ -606,15 +632,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
next_id,
)
- if not row_updated:
- return None
-
- self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_room_account_data_for_user.invalidate((user_id,))
- self.get_account_data_for_room.invalidate((user_id, room_id))
- self.get_account_data_for_room_and_type.prefill(
- (user_id, room_id, account_data_type), {}
- )
+ if row_updated:
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_room_account_data_for_user.invalidate((user_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id))
+ self.get_account_data_for_room_and_type.prefill(
+ (user_id, room_id, account_data_type), {}
+ )
return self._account_data_id_gen.get_current_token()
@@ -632,7 +656,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
The maximum stream ID.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
@@ -722,7 +745,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self,
user_id: str,
account_data_type: str,
- ) -> Optional[int]:
+ ) -> int:
"""
Delete a single piece of user account data by type.
@@ -739,7 +762,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
to delete.
"""
assert self._can_write_to_account_data
- assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_user_txn(
txn: LoggingTransaction, next_id: int
@@ -809,14 +831,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
next_id,
)
- if not row_updated:
- return None
-
- self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_global_account_data_for_user.invalidate((user_id,))
- self.get_global_account_data_by_type_for_user.prefill(
- (user_id, account_data_type), {}
- )
+ if row_updated:
+ self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+ self.get_global_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_by_type_for_user.prefill(
+ (user_id, account_data_type), {}
+ )
return self._account_data_id_gen.get_current_token()
|