diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/app/synchrotron.py | 13 | ||||
-rw-r--r-- | synapse/replication/resource.py | 21 | ||||
-rw-r--r-- | synapse/replication/slave/storage/_base.py | 30 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 11 | ||||
-rw-r--r-- | synapse/storage/_base.py | 47 |
5 files changed, 92 insertions, 30 deletions
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 3dca1c37a0..207a75d89e 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -338,16 +338,10 @@ class SynchrotronServer(HomeServer): http_client = self.get_simple_http_client() store = self.get_datastore() replication_url = self.config.worker_replication_url - clock = self.get_clock() notifier = self.get_notifier() presence_handler = self.get_presence_handler() typing_handler = self.get_typing_handler() - def expire_broken_caches(): - store.who_forgot_in_room.invalidate_all() - store.get_presence_list_accepted.invalidate_all() - store.get_presence_list_observers_accepted.invalidate_all() - def notify_from_stream( result, stream_name, stream_key, room=None, user=None ): @@ -409,19 +403,12 @@ class SynchrotronServer(HomeServer): result, "typing", "typing_key", room="room_id" ) - next_expire_broken_caches_ms = 0 while True: try: args = store.stream_positions() args.update(typing_handler.stream_positions()) args["timeout"] = 30000 result = yield http_client.get_json(replication_url, args=args) - now_ms = clock.time_msec() - if now_ms > next_expire_broken_caches_ms: - expire_broken_caches() - next_expire_broken_caches_ms = ( - now_ms + store.BROKEN_CACHE_EXPIRY_MS - ) yield store.process_replication(result) typing_handler.process_replication(result) yield presence_handler.process_replication(result) diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 8c2d487ff4..84993b33b3 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -41,6 +41,7 @@ STREAM_NAMES = ( ("push_rules",), ("pushers",), ("state",), + ("caches",), ) @@ -70,6 +71,7 @@ class ReplicationResource(Resource): * "backfill": Old events that have been backfilled from other servers. * "push_rules": Per user changes to push rules. * "pushers": Per user changes to their pushers. + * "caches": Cache invalidations. The API takes two additional query parameters: @@ -129,6 +131,7 @@ class ReplicationResource(Resource): push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() pushers_token = self.store.get_pushers_stream_token() state_token = self.store.get_state_stream_token() + caches_token = self.store.get_cache_stream_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -140,6 +143,7 @@ class ReplicationResource(Resource): push_rules_token, pushers_token, state_token, + caches_token, )) @request_handler() @@ -188,6 +192,7 @@ class ReplicationResource(Resource): yield self.push_rules(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams) yield self.state(writer, current_token, limit, request_streams) + yield self.caches(writer, current_token, limit, request_streams) self.streams(writer, current_token, request_streams) logger.info("Replicated %d rows", writer.total) @@ -379,6 +384,20 @@ class ReplicationResource(Resource): "position", "type", "state_key", "event_id" )) + @defer.inlineCallbacks + def caches(self, writer, current_token, limit, request_streams): + current_position = current_token.caches + + caches = request_streams.get("caches") + + if caches is not None: + updated_caches = yield self.store.get_all_updated_caches( + caches, current_position, limit + ) + writer.write_header_and_rows("caches", updated_caches, ( + "position", "cache_func", "keys", "invalidation_ts" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -407,7 +426,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers", "state" + "push_rules", "pushers", "state", "caches", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 46e43ce1c7..24c9946d6a 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,15 +14,43 @@ # limitations under the License. from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine from twisted.internet import defer +from ._slaved_id_tracker import SlavedIdTracker + +import logging + +logger = logging.getLogger(__name__) + class BaseSlavedStore(SQLBaseStore): def __init__(self, db_conn, hs): super(BaseSlavedStore, self).__init__(hs) + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = SlavedIdTracker( + db_conn, "cache_stream", "stream_id", + ) + else: + self._cache_id_gen = None def stream_positions(self): - return {} + pos = {} + if self._cache_id_gen: + pos["caches"] = self._cache_id_gen.get_current_token() + return pos def process_replication(self, result): + stream = result.get("caches") + if stream: + for row in stream["rows"]: + ( + position, cache_func, keys, invalidation_ts, + ) = row + + try: + getattr(self, cache_func).invalidate(tuple(keys)) + except AttributeError: + logger.warn("Got unexpected cache_func: %r", cache_func) + self._cache_id_gen.advance(int(stream["position"])) return defer.succeed(None) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a0c029a2fc..8af492b69f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -50,6 +50,7 @@ from .openid import OpenIdStore from .client_ips import ClientIpStore from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator +from .engines import PostgresEngine from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -122,9 +123,13 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")], ) - self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_stream", "stream_id", - ) + + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = StreamIdGenerator( + db_conn, "cache_stream", "stream_id", + ) + else: + self._cache_id_gen = None events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 02d9098ddd..e3edc2cde6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache from synapse.util.caches import intern_dict +from synapse.storage.engines import PostgresEngine import synapse.metrics @@ -864,21 +865,43 @@ class SQLBaseStore(object): def _invalidate_cache_and_stream(self, txn, cache_func, keys): txn.call_after(cache_func.invalidate, keys) - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_after(ctx.__exit__, None, None, None) + if isinstance(self.database_engine, PostgresEngine): + ctx = self._cache_id_gen.get_next() + stream_id = ctx.__enter__() + txn.call_after(ctx.__exit__, None, None, None) + + self._simple_insert_txn( + txn, + table="cache_stream", + values={ + "stream_id": stream_id, + "cache_func": cache_func.__name__, + "keys": list(keys), + "invalidation_ts": self.clock.time_msec(), + } + ) - self._simple_insert_txn( - txn, - table="cache_stream", - values={ - "stream_id": stream_id, - "cache_func": cache_func.__name__, - "keys": list(keys), - "invalidation_ts": self.clock.time_msec(), - } + def get_all_updated_caches(self, last_id, current_id, limit): + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts FROM cache_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit,)) + return txn.fetchall() + return self.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn ) + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying |