diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 13de5f1f62..59073c0a42 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta):
self.db = database
self.rand = random.SystemRandom()
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ pass
+
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index ceba10882c..cd2a1f0461 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -26,13 +26,14 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
ChainedIdGenerator,
IdGenerator,
+ MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .cache import CacheInvalidationStore
+from .cache import CacheInvalidationWorkerStore
from .client_ips import ClientIpStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
@@ -112,8 +113,8 @@ class DataStore(
MonthlyActiveUsersStore,
StatsStore,
RelationsStore,
- CacheInvalidationStore,
UIAuthStore,
+ CacheInvalidationWorkerStore,
):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
@@ -170,8 +171,14 @@ class DataStore(
)
if isinstance(self.database_engine, PostgresEngine):
- self._cache_id_gen = StreamIdGenerator(
- db_conn, "cache_invalidation_stream", "stream_id"
+ self._cache_id_gen = MultiWriterIdGenerator(
+ db_conn,
+ database,
+ instance_name="master",
+ table="cache_invalidation_stream_by_instance",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="cache_invalidation_stream_seq",
)
else:
self._cache_id_gen = None
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 4dc5da3fe8..342a87a46b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,11 +16,10 @@
import itertools
import logging
-from typing import Any, Iterable, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Any, Iterable, Optional
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
- def get_all_updated_caches(self, last_id, current_id, limit):
+ def __init__(self, database: Database, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self._instance_name = hs.get_instance_name()
+
+ async def get_all_updated_caches(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ):
+ """Fetches cache invalidation rows between the two given IDs written
+ by the given instance. Returns at most `limit` rows.
+ """
+
if last_id == current_id:
- return defer.succeed([])
+ return []
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
+ sql = """
+ SELECT stream_id, cache_func, keys, invalidation_ts
+ FROM cache_invalidation_stream_by_instance
+ WHERE stream_id > ? AND instance_name = ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, instance_name, limit))
return txn.fetchall()
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == "caches":
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(instance_name, token)
-class CacheInvalidationStore(CacheInvalidationWorkerStore):
- async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
- """Invalidates the cache and adds it to the cache stream so slaves
- will know to invalidate their caches.
+ for row in rows:
+ if row.cache_func == CURRENT_STATE_CACHE_NAME:
+ if row.keys is None:
+ raise Exception(
+ "Can't send an 'invalidate all' for current state cache"
+ )
- This should only be used to invalidate caches where slaves won't
- otherwise know from other replication streams that the cache should
- be invalidated.
- """
- cache_func = getattr(self, cache_name, None)
- if not cache_func:
- return
-
- cache_func.invalidate(keys)
- await self.runInteraction(
- "invalidate_cache_and_stream",
- self._send_invalidation_to_replication,
- cache_func.__name__,
- keys,
- )
+ room_id = row.keys[0]
+ members_changed = set(row.keys[1:])
+ self._invalidate_state_caches(room_id, members_changed)
+ else:
+ self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
- ctx = self._cache_id_gen.get_next()
- stream_id = ctx.__enter__()
- txn.call_on_exception(ctx.__exit__, None, None, None)
- txn.call_after(ctx.__exit__, None, None, None)
+ stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
if keys is not None:
@@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
self.db.simple_insert_txn(
txn,
- table="cache_invalidation_stream",
+ table="cache_invalidation_stream_by_instance",
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
},
)
- def get_cache_stream_token(self):
+ def get_cache_stream_token(self, instance_name):
if self._cache_id_gen:
- return self._cache_id_gen.get_current_token()
+ return self._cache_id_gen.get_current_token(instance_name)
else:
return 0
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
new file mode 100644
index 0000000000..aa46eb0e10
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We keep the old table here to enable us to roll back. It doesn't matter
+-- that we have dropped all the data here.
+TRUNCATE cache_invalidation_stream;
+
+CREATE TABLE cache_invalidation_stream_by_instance (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ cache_func TEXT NOT NULL,
+ keys TEXT[],
+ invalidation_ts BIGINT
+);
+
+CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
+
+CREATE SEQUENCE cache_invalidation_stream_seq;
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1712932f31..640f242584 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -29,6 +29,8 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
SCHEMA_VERSION = 58
dir_path = os.path.abspath(os.path.dirname(__file__))
|