diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/_base.py | 76 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 4 | ||||
-rw-r--r-- | synapse/storage/events.py | 32 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 4 | ||||
-rw-r--r-- | synapse/storage/transactions.py | 6 |
5 files changed, 97 insertions, 25 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7c3cf03c8..ee5587c721 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -31,7 +31,9 @@ import functools import simplejson as json import sys import time +import threading +DEBUG_CACHES = False logger = logging.getLogger(__name__) @@ -68,9 +70,20 @@ class Cache(object): self.name = name self.keylen = keylen - + self.sequence = 0 + self.thread = None caches_by_name[name] = self.cache + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + def get(self, *keyargs): if len(keyargs) != self.keylen: raise ValueError("Expected a key to have %d items", self.keylen) @@ -82,6 +95,13 @@ class Cache(object): cache_counter.inc_misses(self.name) raise KeyError() + def update(self, sequence, *args): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(*args) + def prefill(self, *args): # because I can't *keyargs, value keyargs = args[:-1] value = args[-1] @@ -96,9 +116,12 @@ class Cache(object): self.cache[keyargs] = value def invalidate(self, *keyargs): + self.check_thread() if len(keyargs) != self.keylen: raise ValueError("Expected a key to have %d items", self.keylen) - + # Increment the sequence number so that any SELECT statements that + # raced with the INSERT don't update the cache (SYN-369) + self.sequence += 1 self.cache.pop(keyargs, None) @@ -128,11 +151,26 @@ def cached(max_entries=1000, num_args=1, lru=False): @defer.inlineCallbacks def wrapped(self, *keyargs): try: - defer.returnValue(cache.get(*keyargs)) + cached_result = cache.get(*keyargs) + if DEBUG_CACHES: + actual_result = yield orig(self, *keyargs) + if actual_result != cached_result: + logger.error( + "Stale cache entry %s%r: cached: %r, actual %r", + orig.__name__, keyargs, + cached_result, actual_result, + ) + raise ValueError("Stale cache entry") + defer.returnValue(cached_result) except KeyError: + # Get the sequence number of the cache before reading from the + # database so that we can tell if the cache is invalidated + # while the SELECT is executing (SYN-369) + sequence = cache.sequence + ret = yield orig(self, *keyargs) - cache.prefill(*keyargs + (ret,)) + cache.update(sequence, *keyargs + (ret,)) defer.returnValue(ret) @@ -147,12 +185,20 @@ class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method.""" - __slots__ = ["txn", "name", "database_engine"] + __slots__ = ["txn", "name", "database_engine", "after_callbacks"] - def __init__(self, txn, name, database_engine): + def __init__(self, txn, name, database_engine, after_callbacks): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) object.__setattr__(self, "database_engine", database_engine) + object.__setattr__(self, "after_callbacks", after_callbacks) + + def call_after(self, callback, *args): + """Call the given callback on the main twisted thread after the + transaction has finished. Used to invalidate the caches on the + correct thread. + """ + self.after_callbacks.append((callback, args)) def __getattr__(self, name): return getattr(self.txn, name) @@ -299,6 +345,8 @@ class SQLBaseStore(object): start_time = time.time() * 1000 + after_callbacks = [] + def inner_func(conn, *args, **kwargs): with LoggingContext("runInteraction") as context: if self.database_engine.is_connection_closed(conn): @@ -323,10 +371,10 @@ class SQLBaseStore(object): while True: try: txn = conn.cursor() - return func( - LoggingTransaction(txn, name, self.database_engine), - *args, **kwargs + txn = LoggingTransaction( + txn, name, self.database_engine, after_callbacks ) + return func(txn, *args, **kwargs) except self.database_engine.module.OperationalError as e: # This can happen if the database disappears mid # transaction. @@ -375,6 +423,8 @@ class SQLBaseStore(object): result = yield self._db_pool.runWithConnection( inner_func, *args, **kwargs ) + for after_callback, after_args in after_callbacks: + after_callback(*after_args) defer.returnValue(result) def cursor_to_dict(self, cursor): @@ -453,6 +503,14 @@ class SQLBaseStore(object): if not values: return + # This is a *slight* abomination to get a list of tuples of key names + # and a list of tuples of value names. + # + # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] + # + # The sort is to ensure that we don't rely on dictionary iteration + # order. keys, vals = zip(*[ zip( *(sorted(i.items(), key=lambda kv: kv[0])) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 36b1feac60..74b4e23590 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -332,7 +332,9 @@ class EventFederationStore(SQLBaseStore): ) txn.execute(query) - self.get_latest_event_ids_in_room.invalidate(room_id) + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, room_id + ) def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 34bd49cfe9..38395c66ab 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -93,7 +93,7 @@ class EventsStore(SQLBaseStore): current_state=None): # Remove the any existing cache entries for the event_id - self._invalidate_get_event_cache(event.event_id) + txn.call_after(self._invalidate_get_event_cache, event.event_id) if stream_ordering is None: with self._stream_id_gen.get_next_txn(txn) as stream_ordering: @@ -113,19 +113,24 @@ class EventsStore(SQLBaseStore): keyvalues={"room_id": event.room_id}, ) - self._simple_insert_many_txn( - txn, - "current_state_events", - [ + for s in current_state: + if s.type == EventTypes.Member: + txn.call_after( + self.get_rooms_for_user.invalidate, s.state_key + ) + txn.call_after( + self.get_joined_hosts_for_room.invalidate, s.room_id + ) + self._simple_insert_txn( + txn, + "current_state_events", { "event_id": s.event_id, "room_id": s.room_id, "type": s.type, "state_key": s.state_key, } - for s in current_state - ], - ) + ) outlier = event.internal_metadata.is_outlier() @@ -261,7 +266,9 @@ class EventsStore(SQLBaseStore): ) if context.rejected: - self._store_rejections_txn(txn, event.event_id, context.rejected) + self._store_rejections_txn( + txn, event.event_id, context.rejected + ) for hash_alg, hash_base64 in event.hashes.items(): hash_bytes = decode_base64(hash_base64) @@ -273,7 +280,8 @@ class EventsStore(SQLBaseStore): for alg, hash_base64 in prev_hashes.items(): hash_bytes = decode_base64(hash_base64) self._store_prev_event_hash_txn( - txn, event.event_id, prev_event_id, alg, hash_bytes + txn, event.event_id, prev_event_id, alg, + hash_bytes ) self._simple_insert_many_txn( @@ -340,9 +348,11 @@ class EventsStore(SQLBaseStore): } ) + return + def _store_redaction(self, txn, event): # invalidate the cache for the redacted event - self._invalidate_get_event_cache(event.redacts) + txn.call_after(self._invalidate_get_event_cache, event.redacts) txn.execute( "INSERT INTO redactions (event_id, redacts) VALUES (?,?)", (event.event_id, event.redacts) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 09fb77a194..839c74f63a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -64,8 +64,8 @@ class RoomMemberStore(SQLBaseStore): } ) - self.get_rooms_for_user.invalidate(target_user_id) - self.get_joined_hosts_for_room.invalidate(event.room_id) + txn.call_after(self.get_rooms_for_user.invalidate, target_user_id) + txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 89dd7d8947..624da4a9dc 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -17,6 +17,7 @@ from ._base import SQLBaseStore, cached from collections import namedtuple +from syutil.jsonutil import encode_canonical_json import logging logger = logging.getLogger(__name__) @@ -82,7 +83,7 @@ class TransactionStore(SQLBaseStore): "transaction_id": transaction_id, "origin": origin, "response_code": code, - "response_json": response_dict, + "response_json": buffer(encode_canonical_json(response_dict)), }, or_ignore=True, desc="set_received_txn_response", @@ -161,7 +162,8 @@ class TransactionStore(SQLBaseStore): return self.runInteraction( "delivered_txn", self._delivered_txn, - transaction_id, destination, code, response_dict + transaction_id, destination, code, + buffer(encode_canonical_json(response_dict)), ) def _delivered_txn(self, txn, transaction_id, destination, |