diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index bad8260892..68896f34af 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,14 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
from typing import Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
+class AccountDataWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
+ self._instance_name = hs.get_instance_name()
+
+ if isinstance(database.engine, PostgresEngine):
+ self._can_write_to_account_data = (
+ self._instance_name in hs.config.worker.writers.account_data
+ )
+
+ self._account_data_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="account_data",
+ instance_name=self._instance_name,
+ tables=[
+ ("room_account_data", "instance_name", "stream_id"),
+ ("room_tags_revisions", "instance_name", "stream_id"),
+ ("account_data", "instance_name", "stream_id"),
+ ],
+ sequence_name="account_data_sequence",
+ writers=hs.config.worker.writers.account_data,
+ )
+ else:
+ self._can_write_to_account_data = True
+
+ # We shouldn't be running in worker mode with SQLite, but its useful
+ # to support it for unit tests.
+ #
+ # If this process is the writer than we need to use
+ # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.events:
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn,
+ "room_account_data",
+ "stream_id",
+ extra_tables=[("room_tags_revisions", "stream_id")],
+ )
+ else:
+ self._account_data_id_gen = SlavedIdTracker(
+ db_conn,
+ "room_account_data",
+ "stream_id",
+ extra_tables=[("room_tags_revisions", "stream_id")],
+ )
+
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
@@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
super().__init__(database, db_conn, hs)
- @abc.abstractmethod
- def get_max_account_data_stream_id(self):
+ def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
Returns:
int
"""
- raise NotImplementedError()
+ return self._account_data_id_gen.get_current_token()
@cached()
async def get_account_data_for_user(
@@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
)
)
-
-class AccountDataStore(AccountDataWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- self._account_data_id_gen = StreamIdGenerator(
- db_conn,
- "room_account_data",
- "stream_id",
- extra_tables=[("room_tags_revisions", "stream_id")],
- )
-
- super().__init__(database, db_conn, hs)
-
- def get_max_account_data_stream_id(self) -> int:
- """Get the current max stream id for the private user data stream
-
- Returns:
- The maximum stream ID.
- """
- return self._account_data_id_gen.get_current_token()
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == TagAccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ elif stream_name == AccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ if not row.room_id:
+ self.get_global_account_data_by_type_for_user.invalidate(
+ (row.data_type, row.user_id)
+ )
+ self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
+ self.get_account_data_for_room_and_type.invalidate(
+ (row.user_id, row.room_id, row.data_type)
+ )
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
+ assert self._can_write_to_account_data
+
content_json = json_encoder.encode(content)
async with self._account_data_id_gen.get_next() as next_id:
@@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
+ assert self._can_write_to_account_data
+
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"add_user_account_data",
@@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore):
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+
+
+class AccountDataStore(AccountDataWorkerStore):
+ pass
|