diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index c4de07a0a8..ae561a2da3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,9 +160,13 @@ class DataStore(
database,
stream_name="caches",
instance_name=hs.get_instance_name(),
- table="cache_invalidation_stream_by_instance",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[
+ (
+ "cache_invalidation_stream_by_instance",
+ "instance_name",
+ "stream_id",
+ )
+ ],
sequence_name="cache_invalidation_stream_seq",
writers=[],
)
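The `MultiWriterIdGenerator` call sites in this patch all change in the same mechanical way: the old `table` / `instance_column` / `id_column` keyword arguments become a single `tables` list of `(table, instance_column, id_column)` tuples, so one generator can track IDs spread across several tables. A minimal sketch of the translation, in plain Python with an illustrative helper name that is not part of Synapse:

```python
from typing import List, Tuple


def single_table_to_tables(
    table: str, instance_column: str, id_column: str
) -> List[Tuple[str, str, str]]:
    """Wrap the old single-table keyword arguments in the new `tables` form."""
    return [(table, instance_column, id_column)]


# The cache invalidation stream uses exactly one table, so it becomes a
# one-element list:
print(
    single_table_to_tables(
        "cache_invalidation_stream_by_instance", "instance_name", "stream_id"
    )
)
# [('cache_invalidation_stream_by_instance', 'instance_name', 'stream_id')]
```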
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index bad8260892..68896f34af 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,14 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
from typing import Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
-    """This is an abstract base class where subclasses must implement
-    `get_max_account_data_stream_id` which can be called in the initializer.
-    """
+class AccountDataWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
+ self._instance_name = hs.get_instance_name()
+
+ if isinstance(database.engine, PostgresEngine):
+ self._can_write_to_account_data = (
+ self._instance_name in hs.config.worker.writers.account_data
+ )
+
+ self._account_data_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="account_data",
+ instance_name=self._instance_name,
+ tables=[
+ ("room_account_data", "instance_name", "stream_id"),
+ ("room_tags_revisions", "instance_name", "stream_id"),
+ ("account_data", "instance_name", "stream_id"),
+ ],
+ sequence_name="account_data_sequence",
+ writers=hs.config.worker.writers.account_data,
+ )
+ else:
+ self._can_write_to_account_data = True
+
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer then we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.events:
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn,
+ "room_account_data",
+ "stream_id",
+ extra_tables=[("room_tags_revisions", "stream_id")],
+ )
+ else:
+ self._account_data_id_gen = SlavedIdTracker(
+ db_conn,
+ "room_account_data",
+ "stream_id",
+ extra_tables=[("room_tags_revisions", "stream_id")],
+ )
+
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
@@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
super().__init__(database, db_conn, hs)
- @abc.abstractmethod
- def get_max_account_data_stream_id(self):
+ def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream ID for account data stream
Returns:
int
"""
- raise NotImplementedError()
+ return self._account_data_id_gen.get_current_token()
@cached()
async def get_account_data_for_user(
@@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
)
)
-
-class AccountDataStore(AccountDataWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- self._account_data_id_gen = StreamIdGenerator(
- db_conn,
- "room_account_data",
- "stream_id",
- extra_tables=[("room_tags_revisions", "stream_id")],
- )
-
- super().__init__(database, db_conn, hs)
-
- def get_max_account_data_stream_id(self) -> int:
- """Get the current max stream id for the private user data stream
-
- Returns:
- The maximum stream ID.
- """
- return self._account_data_id_gen.get_current_token()
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == TagAccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ elif stream_name == AccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ for row in rows:
+ if not row.room_id:
+ self.get_global_account_data_by_type_for_user.invalidate(
+ (row.data_type, row.user_id)
+ )
+ self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
+ self.get_account_data_for_room_and_type.invalidate(
+ (row.user_id, row.room_id, row.data_type)
+ )
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
+ assert self._can_write_to_account_data
+
content_json = json_encoder.encode(content)
async with self._account_data_id_gen.get_next() as next_id:
@@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns:
The maximum stream ID.
"""
+ assert self._can_write_to_account_data
+
async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"add_user_account_data",
@@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore):
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+
+
+class AccountDataStore(AccountDataWorkerStore):
+ pass
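`AccountDataWorkerStore` now chooses its own ID generator: a `MultiWriterIdGenerator` on Postgres (with writes only allowed on instances listed in `writers.account_data`), and a `StreamIdGenerator` or `SlavedIdTracker` on SQLite. All of these expose `get_current_token()` and, for writers, `get_next()` as an async context manager, which is why `get_max_account_data_stream_id` can be implemented once on the worker store and the write paths simply assert `_can_write_to_account_data`. The toy stand-in below is not Synapse's implementation; it only sketches that shared interface as used in this diff:

```python
import asyncio
from contextlib import asynccontextmanager


class ToyIdGenerator:
    """Toy stand-in for StreamIdGenerator/MultiWriterIdGenerator: hands out
    increasing stream IDs and advances the current token after each write."""

    def __init__(self, current: int = 0) -> None:
        self._current = current
        self._next = current

    @asynccontextmanager
    async def get_next(self):
        self._next += 1
        next_id = self._next
        try:
            yield next_id
        finally:
            # The real generators only advance the token once the database
            # transaction that used `next_id` has completed.
            self._current = max(self._current, next_id)

    def get_current_token(self) -> int:
        return self._current


async def main() -> None:
    gen = ToyIdGenerator(current=41)
    async with gen.get_next() as next_id:
        print("writing account data at stream_id", next_id)
    print("max account data stream id:", gen.get_current_token())


asyncio.run(main())
```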
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 58d3f71e45..31f70ac5ef 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
db=database,
stream_name="to_device",
instance_name=self._instance_name,
- table="device_inbox",
- instance_column="instance_name",
- id_column="stream_id",
+ tables=[("device_inbox", "instance_name", "stream_id")],
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
)
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index e5c03cc609..1b657191a9 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
(rotate_to_stream_ordering,),
)
+ def _remove_old_push_actions_before_txn(
+ self, txn, room_id, user_id, stream_ordering
+ ):
+ """
+ Purges old push actions for a user and room before a given
+ stream_ordering.
+
+        We however keep a month's worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transaction
+ room_id: Room ID to delete from
+ user_id: user ID to delete for
+ stream_ordering: The lowest stream ordering which will
+ not be deleted.
+ """
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (room_id, user_id),
+ )
+
+        # We need to join on the events table to get the received_ts for
+        # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore, we can
+        # only identify event_push_actions by a tuple of room_id and event_id,
+        # so we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_ordering before that.
+ txn.execute(
+ "DELETE FROM event_push_actions "
+ " WHERE user_id = ? AND room_id = ? AND "
+ " stream_ordering <= ?"
+ " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+ )
+
+ txn.execute(
+ """
+ DELETE FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+ """,
+ (room_id, user_id, stream_ordering),
+ )
+
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
- def _remove_old_push_actions_before_txn(
- self, txn, room_id, user_id, stream_ordering
- ):
- """
- Purges old push actions for a user and room before a given
- stream_ordering.
-
- We however keep a months worth of highlighted notifications, so that
- users can still get a list of recent highlights.
-
- Args:
- txn: The transcation
- room_id: Room ID to delete from
- user_id: user ID to delete for
- stream_ordering: The lowest stream ordering which will
- not be deleted.
- """
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id, user_id),
- )
-
- # We need to join on the events table to get the received_ts for
- # event_push_actions and sqlite won't let us use a join in a delete so
- # we can't just delete where received_ts < x. Furthermore we can
- # only identify event_push_actions by a tuple of room_id, event_id
- # we we can't use a subquery.
- # Instead, we look up the stream ordering for the last event in that
- # room received before the threshold time and delete event_push_actions
- # in the room with a stream_odering before that.
- txn.execute(
- "DELETE FROM event_push_actions "
- " WHERE user_id = ? AND room_id = ? AND "
- " stream_ordering <= ?"
- " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
- )
-
- txn.execute(
- """
- DELETE FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
- """,
- (room_id, user_id, stream_ordering),
- )
-
def _action_has_highlight(actions):
for action in actions:
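`_remove_old_push_actions_before_txn` moves to the worker store unchanged; the retention rule encoded in its `DELETE` is that everything at or below the cutoff is removed, except highlights newer than `stream_ordering_month_ago`, which are kept so users can still list recent highlights. A small pure-Python restatement of that predicate (the function name is illustrative, not part of Synapse):

```python
def should_delete_push_action(
    stream_ordering: int,
    highlight: bool,
    cutoff_stream_ordering: int,
    stream_ordering_month_ago: int,
) -> bool:
    """Mirror of the SQL WHERE clause: delete rows at or below the cutoff,
    but keep highlighted notifications from roughly the last month."""
    if stream_ordering > cutoff_stream_ordering:
        return False
    if highlight:
        return stream_ordering < stream_ordering_month_ago
    return True


# A non-highlight below the cutoff is deleted; a recent highlight is kept.
assert should_delete_push_action(90, False, 100, 50)
assert not should_delete_push_action(90, True, 100, 50)
```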
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4732685f6e..71d823be72 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
stream_name="events",
instance_name=hs.get_instance_name(),
- table="events",
- instance_column="instance_name",
- id_column="stream_ordering",
+ tables=[("events", "instance_name", "stream_ordering")],
sequence_name="events_stream_seq",
writers=hs.config.worker.writers.events,
)
@@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database,
stream_name="backfill",
instance_name=hs.get_instance_name(),
- table="events",
- instance_column="instance_name",
- id_column="stream_ordering",
+ tables=[("events", "instance_name", "stream_ordering")],
sequence_name="events_backfill_stream_seq",
positive=False,
writers=hs.config.worker.writers.events,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1e7949a323..e0e57f0578 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,15 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
- """This is an abstract base class where subclasses must implement
- `get_max_receipt_stream_id` which can be called in the initializer.
- """
-
+class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
+ self._instance_name = hs.get_instance_name()
+
+ if isinstance(database.engine, PostgresEngine):
+ self._can_write_to_receipts = (
+ self._instance_name in hs.config.worker.writers.receipts
+ )
+
+ self._receipts_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+                stream_name="receipts",
+ instance_name=self._instance_name,
+ tables=[("receipts_linearized", "instance_name", "stream_id")],
+ sequence_name="receipts_sequence",
+ writers=hs.config.worker.writers.receipts,
+ )
+ else:
+ self._can_write_to_receipts = True
+
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer then we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.events:
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+ else:
+ self._receipts_id_gen = SlavedIdTracker(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
super().__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
- @abc.abstractmethod
def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream
Returns:
int
"""
- raise NotImplementedError()
+ return self._receipts_id_gen.get_current_token()
@cached()
async def get_users_with_read_receipts_in_room(self, room_id):
@@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
-class ReceiptsStore(ReceiptsWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- # We instantiate this first as the ReceiptsWorkerStore constructor
- # needs to be able to call get_max_receipt_stream_id
- self._receipts_id_gen = StreamIdGenerator(
- db_conn, "receipts_linearized", "stream_id"
+ def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+ self.get_receipts_for_user.invalidate((user_id, receipt_type))
+ self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+ self.get_last_receipt_event_id_for_user.invalidate(
+ (user_id, room_id, receipt_type)
)
+ self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+ self.get_receipts_for_room.invalidate((room_id, receipt_type))
+
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == ReceiptsStream.NAME:
+ self._receipts_id_gen.advance(instance_name, token)
+ for row in rows:
+ self.invalidate_caches_for_receipt(
+ row.room_id, row.receipt_type, row.user_id
+ )
+ self._receipts_stream_cache.entity_has_changed(row.room_id, token)
- super().__init__(database, db_conn, hs)
-
- def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_current_token()
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
+ assert self._can_write_to_receipts
+
res = self.db_pool.simple_select_one_txn(
txn,
table="events",
@@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
)
return None
- txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
- txn.call_after(
- self._invalidate_get_users_with_receipts_in_room,
- room_id,
- receipt_type,
- user_id,
- )
- txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
- # FIXME: This shouldn't invalidate the whole cache
txn.call_after(
- self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
+ self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
)
txn.call_after(
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
)
- txn.call_after(
- self.get_last_receipt_event_id_for_user.invalidate,
- (user_id, room_id, receipt_type),
- )
-
self.db_pool.simple_upsert_txn(
txn,
table="receipts_linearized",
@@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
Automatically does conversion between linearized and graph
representations.
"""
+ assert self._can_write_to_receipts
+
if not event_ids:
return None
@@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data
):
+ assert self._can_write_to_receipts
+
return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
@@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data
):
+ assert self._can_write_to_receipts
+
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
@@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"data": json_encoder.encode(data),
},
)
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+ pass
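The receipts refactor pulls the per-receipt cache invalidations into `invalidate_caches_for_receipt`, which is now shared by the local write path (deferred via `txn.call_after` until the transaction completes) and by `process_replication_rows` when rows arrive from another writer. A toy sketch of that pattern with stand-in names, not Synapse's real transaction object:

```python
from typing import Callable, List


class ToyTxn:
    """Toy transaction that queues callbacks to run after a successful commit,
    loosely mirroring the `txn.call_after` pattern used above."""

    def __init__(self) -> None:
        self._after: List[Callable[[], None]] = []

    def call_after(self, fn: Callable, *args) -> None:
        self._after.append(lambda: fn(*args))

    def commit(self) -> None:
        for callback in self._after:
            callback()


def invalidate_caches_for_receipt(room_id: str, receipt_type: str, user_id: str) -> None:
    print(f"invalidating receipt caches for {user_id} / {receipt_type} in {room_id}")


# Local write path: the invalidation runs only once the transaction commits.
txn = ToyTxn()
txn.call_after(
    invalidate_caches_for_receipt, "!room:example.org", "m.read", "@alice:example.org"
)
txn.commit()

# Replication path: the same function is called directly for each row.
invalidate_caches_for_receipt("!room:example.org", "m.read", "@bob:example.org")
```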
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
new file mode 100644
index 0000000000..46abf8d562
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
+ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
+ALTER TABLE account_data ADD COLUMN instance_name TEXT;
+
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
new file mode 100644
index 0000000000..4a6e6c74f5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
@@ -0,0 +1,32 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
+
+-- We need to take the max across all the account_data tables as they share the
+-- ID generator
+SELECT setval('account_data_sequence', (
+ SELECT GREATEST(
+ (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
+ )
+));
+
+CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
+
+SELECT setval('receipts_sequence', (
+ SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
+));
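The Postgres delta seeds each new sequence at the highest stream ID already present in the tables that share the generator (falling back to 1 for empty tables), so the per-table `check_consistency` calls in the updated `MultiWriterIdGenerator` pass and newly allocated IDs cannot collide with existing rows. A small Python restatement of what the `setval(... GREATEST(COALESCE(MAX(stream_id), 1) ...))` expression computes, with an illustrative function name:

```python
from typing import Dict, Optional


def sequence_seed(per_table_max_stream_id: Dict[str, Optional[int]]) -> int:
    """COALESCE(MAX(stream_id), 1) per table, then GREATEST across tables."""
    return max(
        (max_id if max_id is not None else 1)
        for max_id in per_table_max_stream_id.values()
    )


print(
    sequence_seed(
        {"room_account_data": 120, "room_tags_revisions": None, "account_data": 87}
    )
)  # -> 120
```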
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 74da9c49f2..50067eabfc 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
return {row["tag"]: db_to_json(row["content"]) for row in rows}
-
-class TagsStore(TagsWorkerStore):
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
@@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
+ assert self._can_write_to_account_data
+
content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
@@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
+ assert self._can_write_to_account_data
def remove_tag_txn(txn, next_id):
sql = (
@@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore):
room_id: The ID of the room.
next_id: The the revision to advance to.
"""
+ assert self._can_write_to_account_data
txn.call_after(
self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore):
# which stream_id ends up in the table, as long as it is higher
# than the id that the client has.
pass
+
+
+class TagsStore(TagsWorkerStore):
+ pass
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 133c0e7a28..39a3ab1162 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -17,7 +17,7 @@ import logging
import threading
from collections import deque
from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
import attr
from typing_extensions import Deque
@@ -186,11 +186,12 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
- stream_name: A name for the stream.
+ stream_name: A name for the stream, for use in the `stream_positions`
+ table. (Does not need to be the same as the replication stream name)
instance_name: The name of this instance.
- table: Database table associated with stream.
- instance_column: Column that stores the row's writer's instance name
- id_column: Column that stores the stream ID.
+        tables: List of tables associated with the stream. Each entry is a
+            tuple of (table name, column that stores the writer's instance
+            name, column that stores the stream ID).
sequence_name: The name of the postgres sequence used to generate new
IDs.
writers: A list of known writers to use to populate current positions
@@ -206,9 +207,7 @@ class MultiWriterIdGenerator:
db: DatabasePool,
stream_name: str,
instance_name: str,
- table: str,
- instance_column: str,
- id_column: str,
+ tables: List[Tuple[str, str, str]],
sequence_name: str,
writers: List[str],
positive: bool = True,
@@ -260,15 +259,16 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
- self._sequence_gen.check_consistency(
- db_conn, table=table, id_column=id_column, positive=positive
- )
+ for table, _, id_column in tables:
+ self._sequence_gen.check_consistency(
+ db_conn, table=table, id_column=id_column, positive=positive
+ )
# This goes and fills out the above state from the database.
- self._load_current_ids(db_conn, table, instance_column, id_column)
+ self._load_current_ids(db_conn, tables)
def _load_current_ids(
- self, db_conn, table: str, instance_column: str, id_column: str
+ self, db_conn, tables: List[Tuple[str, str, str]],
):
cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -306,17 +306,22 @@ class MultiWriterIdGenerator:
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
- sql = """
- SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
- FROM %(table)s
- """ % {
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
- cur.execute(sql)
- (stream_id,) = cur.fetchone()
- self._persisted_upto_position = stream_id
+ max_stream_id = 1
+ for table, _, id_column in tables:
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+
+ max_stream_id = max(max_stream_id, stream_id)
+
+ self._persisted_upto_position = max_stream_id
else:
# If we have a min_stream_id then we pull out everything greater
# than it from the DB so that we can prefill
@@ -329,21 +334,28 @@ class MultiWriterIdGenerator:
# stream positions table before restart (or the stream position
# table otherwise got out of date).
- sql = """
- SELECT %(instance)s, %(id)s FROM %(table)s
- WHERE ? %(cmp)s %(id)s
- """ % {
- "id": id_column,
- "table": table,
- "instance": instance_column,
- "cmp": "<=" if self._positive else ">=",
- }
- cur.execute(sql, (min_stream_id * self._return_factor,))
-
self._persisted_upto_position = min_stream_id
+ rows = []
+ for table, instance_column, id_column in tables:
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ cur.execute(sql, (min_stream_id * self._return_factor,))
+
+ rows.extend(cur)
+
+ # Sort so that we handle rows in order for each instance.
+ rows.sort()
+
with self._lock:
- for (instance, stream_id,) in cur:
+ for (instance, stream_id,) in rows:
stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id)
|