diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 9 | ||||
-rw-r--r-- | synapse/storage/_base.py | 162 | ||||
-rw-r--r-- | synapse/storage/directory.py | 7 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 7 | ||||
-rw-r--r-- | synapse/storage/events.py | 21 | ||||
-rw-r--r-- | synapse/storage/keys.py | 8 | ||||
-rw-r--r-- | synapse/storage/presence.py | 37 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 25 | ||||
-rw-r--r-- | synapse/storage/receipts.py | 90 | ||||
-rw-r--r-- | synapse/storage/registration.py | 17 | ||||
-rw-r--r-- | synapse/storage/room.py | 6 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 13 | ||||
-rw-r--r-- | synapse/storage/schema/delta/22/receipts_index.sql | 18 | ||||
-rw-r--r-- | synapse/storage/state.py | 333 | ||||
-rw-r--r-- | synapse/storage/stream.py | 6 | ||||
-rw-r--r-- | synapse/storage/transactions.py | 3 |
16 files changed, 440 insertions, 322 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 99467dde02..f154b1c8ae 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore, key = (user.to_string(), access_token, device_id, ip) try: - last_seen = self.client_ip_last_seen.get(*key) + last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None @@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore, if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: defer.returnValue(None) - self.client_ip_last_seen.prefill(*key + (now,)) + self.client_ip_last_seen.prefill(key, now) # It's safe not to lock here: a) no unique constraint, # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely @@ -354,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, ) logger.debug("Running script %s", relative_path) module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass elif ext == ".sql": # A plain old .sql file, just read and execute it logger.debug("Applying schema %s", relative_path) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8f812f0fd7..1444767a52 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,21 +17,20 @@ import logging from synapse.api.errors import StoreError from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext -from synapse.util.lrucache import LruCache +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.caches.descriptors import Cache import synapse.metrics from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple, OrderedDict +from collections import namedtuple -import functools import sys import time import threading -DEBUG_CACHES = False logger = logging.getLogger(__name__) @@ -47,159 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time") sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) -caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) - - -class Cache(object): - - def __init__(self, name, max_entries=1000, keylen=1, lru=False): - if lru: - self.cache = LruCache(max_size=max_entries) - self.max_entries = None - else: - self.cache = OrderedDict() - self.max_entries = max_entries - - self.name = name - self.keylen = keylen - self.sequence = 0 - self.thread = None - caches_by_name[name] = self.cache - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get(self, *keyargs): - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) - - if keyargs in self.cache: - cache_counter.inc_hits(self.name) - return self.cache[keyargs] - - cache_counter.inc_misses(self.name) - raise KeyError() - - def update(self, sequence, *args): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) - - def prefill(self, *args): # because I can't *keyargs, value - keyargs = args[:-1] - value = args[-1] - - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) - - if self.max_entries is not None: - while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False) - - self.cache[keyargs] = value - - def invalidate(self, *keyargs): - self.check_thread() - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 - self.cache.pop(keyargs, None) - - def invalidate_all(self): - self.check_thread() - self.sequence += 1 - self.cache.clear() - - -class CacheDescriptor(object): - """ A method decorator that applies a memoizing cache around the function. - - The function is presumed to take zero or more arguments, which are used in - a tuple as the key for the cache. Hits are served directly from the cache; - misses use the function body to generate the value. - - The wrapped function has an additional member, a callable called - "invalidate". This can be used to remove individual entries from the cache. - - The wrapped function has another additional callable, called "prefill", - which can be used to insert values into the cache specifically, without - calling the calculation function. - """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=False): - self.orig = orig - - self.max_entries = max_entries - self.num_args = num_args - self.lru = lru - - def __get__(self, obj, objtype=None): - cache = Cache( - name=self.orig.__name__, - max_entries=self.max_entries, - keylen=self.num_args, - lru=self.lru, - ) - - @functools.wraps(self.orig) - @defer.inlineCallbacks - def wrapped(*keyargs): - try: - cached_result = cache.get(*keyargs[:self.num_args]) - if DEBUG_CACHES: - actual_result = yield self.orig(obj, *keyargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, keyargs, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) - except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = cache.sequence - - ret = yield self.orig(obj, *keyargs) - - cache.update(sequence, *keyargs[:self.num_args] + (ret,)) - - defer.returnValue(ret) - - wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all - wrapped.prefill = cache.prefill - - obj.__dict__[self.orig.__name__] = wrapped - - return wrapped - - -def cached(max_entries=1000, num_args=1, lru=False): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - lru=lru - ) - class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object @@ -321,6 +167,8 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000) + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 2b2bdf8615..d92028ea43 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.errors import SynapseError @@ -104,7 +105,7 @@ class DirectoryStore(SQLBaseStore): }, desc="create_room_alias_association", ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) @defer.inlineCallbacks def delete_room_alias(self, room_alias): @@ -114,7 +115,7 @@ class DirectoryStore(SQLBaseStore): room_alias, ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 45b86c94e8..25cc84eb95 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -15,7 +15,8 @@ from twisted.internet import defer -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from syutil.base64util import encode_base64 import logging @@ -362,7 +363,7 @@ class EventFederationStore(SQLBaseStore): for room_id in events_by_room: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, room_id + self.get_latest_event_ids_in_room.invalidate, (room_id,) ) def get_backfill_events(self, room_id, event_list, limit): @@ -505,4 +506,4 @@ class EventFederationStore(SQLBaseStore): query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ed7ea38804..5b64918024 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore): if current_state: txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases, event.room_id) self._simple_delete_txn( @@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore): if not context.rejected: txn.call_after( self.get_current_state_for_key.invalidate, - event.room_id, event.type, event.state_key - ) + (event.room_id, event.type, event.state_key,) + ) if event.type in [EventTypes.Name, EventTypes.Aliases]: txn.call_after( self.get_room_name_and_aliases.invalidate, - event.room_id + (event.room_id,) ) self._simple_upsert_txn( @@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore): def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True): for get_prev_content in (False, True): - self._get_event_cache.invalidate(event_id, check_redacted, - get_prev_content) + self._get_event_cache.invalidate( + (event_id, check_redacted, get_prev_content) + ) def _get_event_txn(self, txn, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False): @@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore): for event_id in events: try: ret = self._get_event_cache.get( - event_id, check_redacted, get_prev_content + (event_id, check_redacted, get_prev_content,) ) if allow_rejected or not ret.rejected_reason: @@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) defer.returnValue(ev) @@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) return ev diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 940a5f7e08..ffd6daa880 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, cached +from _base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer @@ -71,8 +72,7 @@ class KeyStore(SQLBaseStore): desc="store_server_certificate", ) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_all_server_verify_keys(self, server_name): rows = yield self._simple_select_list( table="server_signature_keys", @@ -132,7 +132,7 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) - self.get_all_server_verify_keys.invalidate(server_name) + self.get_all_server_verify_keys.invalidate((server_name,)) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index fefcf6bce0..34ca3b9a54 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -13,19 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedList from twisted.internet import defer class PresenceStore(SQLBaseStore): def create_presence(self, user_localpart): - return self._simple_insert( + res = self._simple_insert( table="presence", values={"user_id": user_localpart}, desc="create_presence", ) + self.get_presence_state.invalidate((user_localpart,)) + return res + def has_presence_state(self, user_localpart): return self._simple_select_one( table="presence", @@ -35,6 +39,7 @@ class PresenceStore(SQLBaseStore): desc="has_presence_state", ) + @cached(max_entries=2000) def get_presence_state(self, user_localpart): return self._simple_select_one( table="presence", @@ -43,8 +48,27 @@ class PresenceStore(SQLBaseStore): desc="get_presence_state", ) + @cachedList(get_presence_state.cache, list_name="user_localparts") + def get_presence_states(self, user_localparts): + def f(txn): + results = {} + for user_localpart in user_localparts: + res = self._simple_select_one_txn( + txn, + table="presence", + keyvalues={"user_id": user_localpart}, + retcols=["state", "status_msg", "mtime"], + allow_none=True, + ) + if res: + results[user_localpart] = res + + return results + + return self.runInteraction("get_presence_states", f) + def set_presence_state(self, user_localpart, new_state): - return self._simple_update_one( + res = self._simple_update_one( table="presence", keyvalues={"user_id": user_localpart}, updatevalues={"state": new_state["state"], @@ -53,6 +77,9 @@ class PresenceStore(SQLBaseStore): desc="set_presence_state", ) + self.get_presence_state.invalidate((user_localpart,)) + return res + def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( table="presence_allow_inbound", @@ -98,7 +125,7 @@ class PresenceStore(SQLBaseStore): updatevalues={"accepted": True}, desc="set_presence_list_accepted", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -133,4 +160,4 @@ class PresenceStore(SQLBaseStore): "observed_user_id": observed_userid}, desc="del_presence_list", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 4cac118d17..5305b7e122 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer import logging @@ -23,8 +24,7 @@ logger = logging.getLogger(__name__) class PushRuleStore(SQLBaseStore): - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_push_rules_for_user(self, user_name): rows = yield self._simple_select_list( table=PushRuleTable.table_name, @@ -41,8 +41,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rows) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_push_rules_enabled_for_user(self, user_name): results = yield self._simple_select_list( table=PushRuleEnableTable.table_name, @@ -153,11 +152,11 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -189,10 +188,10 @@ class PushRuleStore(SQLBaseStore): new_rule['priority'] = new_prio txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -218,8 +217,8 @@ class PushRuleStore(SQLBaseStore): desc="delete_push_rule", ) - self.get_push_rules_for_user.invalidate(user_name) - self.get_push_rules_enabled_for_user.invalidate(user_name) + self.get_push_rules_for_user.invalidate((user_name,)) + self.get_push_rules_enabled_for_user.invalidate((user_name,)) @defer.inlineCallbacks def set_push_rule_enabled(self, user_name, rule_id, enabled): @@ -240,10 +239,10 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 7a6af98d98..a535063547 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches import cache_counter, caches_by_name from twisted.internet import defer -from synapse.util import unwrapFirstError - from blist import sorteddict import logging import ujson as json @@ -53,19 +53,13 @@ class ReceiptsStore(SQLBaseStore): self, room_ids, from_key ) - results = yield defer.gatherResults( - [ - self.get_linearized_receipts_for_room( - room_id, to_key, from_key=from_key - ) - for room_id in room_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + results = yield self._get_linearized_receipts_for_rooms( + room_ids, to_key, from_key=from_key + ) - defer.returnValue([ev for res in results for ev in res]) + defer.returnValue([ev for res in results.values() for ev in res]) - @defer.inlineCallbacks + @cachedInlineCallbacks(num_args=3, max_entries=5000) def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. @@ -125,11 +119,70 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) + @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", + num_args=3, inlineCallbacks=True) + def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + if not room_ids: + defer.returnValue({}) + + def f(txn): + if from_key: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id IN (%s) AND stream_id > ? AND stream_id <= ?" + ) % ( + ",".join(["?"] * len(room_ids)) + ) + args = list(room_ids) + args.extend([from_key, to_key]) + + txn.execute(sql, args) + else: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id IN (%s) AND stream_id <= ?" + ) % ( + ",".join(["?"] * len(room_ids)) + ) + + args = list(room_ids) + args.append(to_key) + + txn.execute(sql, args) + + return self.cursor_to_dict(txn) + + txn_results = yield self.runInteraction( + "_get_linearized_receipts_for_rooms", f + ) + + results = {} + for row in txn_results: + # We want a single event per room, since we want to batch the + # receipts by room, event and type. + room_event = results.setdefault(row["room_id"], { + "type": "m.receipt", + "room_id": row["room_id"], + "content": {}, + }) + + # The content is of the form: + # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } + event_entry = room_event["content"].setdefault(row["event_id"], {}) + receipt_type = event_entry.setdefault(row["receipt_type"], {}) + + receipt_type[row["user_id"]] = json.loads(row["data"]) + + results = { + room_id: [results[room_id]] if room_id in results else [] + for room_id in room_ids + } + defer.returnValue(results) + def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_max_token(self) - @cached - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_graph_receipts_for_room(self, room_id): """Get receipts for sending to remote servers. """ @@ -305,6 +358,8 @@ class _RoomStreamChangeCache(object): self._room_to_key = {} self._cache = sorteddict() self._earliest_key = None + self.name = "ReceiptsRoomChangeCache" + caches_by_name[self.name] = self._cache @defer.inlineCallbacks def get_rooms_changed(self, store, room_ids, key): @@ -318,8 +373,11 @@ class _RoomStreamChangeCache(object): result = set( self._cache[k] for k in keys[i:] ).intersection(room_ids) + + cache_counter.inc_hits(self.name) else: result = room_ids + cache_counter.inc_misses(self.name) defer.returnValue(result) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 90e2606be2..bf803f2c6e 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached class RegistrationStore(SQLBaseStore): @@ -111,16 +112,16 @@ class RegistrationStore(SQLBaseStore): }) @defer.inlineCallbacks - def user_delete_access_tokens_apart_from(self, user_id, token_id): + def user_delete_access_tokens(self, user_id): yield self.runInteraction( - "user_delete_access_tokens_apart_from", - self._user_delete_access_tokens_apart_from, user_id, token_id + "user_delete_access_tokens", + self._user_delete_access_tokens, user_id ) - def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id): + def _user_delete_access_tokens(self, txn, user_id): txn.execute( - "DELETE FROM access_tokens WHERE user_id = ? AND id != ?", - (user_id, token_id) + "DELETE FROM access_tokens WHERE user_id = ?", + (user_id, ) ) @defer.inlineCallbacks @@ -131,7 +132,7 @@ class RegistrationStore(SQLBaseStore): user_id ) for r in rows: - self.get_user_by_token.invalidate(r) + self.get_user_by_token.invalidate((r,)) @cached() def get_user_by_token(self, token): diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 4612a8aa83..5e07b7e0e5 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks import collections import logging @@ -186,8 +187,7 @@ class RoomStore(SQLBaseStore): } ) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_room_name_and_aliases(self, room_id): def f(txn): sql = ( diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 4db07f6fb4..8eee2dfbcc 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -17,7 +17,8 @@ from twisted.internet import defer from collections import namedtuple -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.constants import Membership from synapse.types import UserID @@ -54,9 +55,9 @@ class RoomMemberStore(SQLBaseStore): ) for event in events: - txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) + txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. @@ -78,7 +79,7 @@ class RoomMemberStore(SQLBaseStore): lambda events: events[0] if events else None ) - @cached() + @cached(max_entries=5000) def get_users_in_room(self, room_id): def f(txn): @@ -154,7 +155,7 @@ class RoomMemberStore(SQLBaseStore): RoomsForUser(**r) for r in self.cursor_to_dict(txn) ] - @cached() + @cached(max_entries=5000) def get_joined_hosts_for_room(self, room_id): return self.runInteraction( "get_joined_hosts_for_room", diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/schema/delta/22/receipts_index.sql new file mode 100644 index 0000000000..b182b2b661 --- /dev/null +++ b/synapse/storage/schema/delta/22/receipts_index.sql @@ -0,0 +1,18 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( + room_id, stream_id +); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 47bec65497..c9110e6304 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import ( + cached, cachedInlineCallbacks, cachedList +) from twisted.internet import defer @@ -44,60 +47,25 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids The return value is a dict mapping group names to lists of events. """ + if not event_ids: + defer.returnValue({}) - def f(txn): - groups = set() - for event_id in event_ids: - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - groups.add(group) - - res = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - - res[group] = state_ids - - return res - - states = yield self.runInteraction( - "get_state_groups", - f, + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, ) - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in states.items() - ], - consumeErrors=True, - ) - - defer.returnValue(dict(state_list)) + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) - @cached(num_args=1) - def _fetch_events_for_group(self, key, events): - return self._get_events( - events, get_prev_content=False - ).addCallback( - lambda evs: (key, evs) - ) + defer.returnValue({ + group: state_map.values() + for group, state_map in group_to_state.items() + }) def _store_state_groups_txn(self, txn, event, context): return self._store_mult_state_groups_txn(txn, [(event, context)]) @@ -189,8 +157,7 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cached(num_args=3) - @defer.inlineCallbacks + @cachedInlineCallbacks(num_args=3) def get_current_state_for_key(self, room_id, event_type, state_key): def f(txn): sql = ( @@ -206,64 +173,254 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + def _get_state_groups_from_groups(self, groups_and_types): + """Returns dictionary state_group -> state event ids + + Args: + groups_and_types (list): list of 2-tuple (`group`, `types`) + """ + def f(txn): + results = {} + for group, types in groups_and_types: + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + + results[group] = [r[0] for r in txn.fetchall()] + + return results + + return self.runInteraction( + "_get_state_groups_from_groups", + f, + ) + @defer.inlineCallbacks - def get_state_for_events(self, room_id, event_ids): + def get_state_for_events(self, room_id, event_ids, types): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. The state dicts will only have the type/state_keys + that are in the `types` list. + + Args: + room_id (str) + event_ids (list) + types (list): List of (type, state_key) tuples which are used to + filter the state fetched. `state_key` may be None, which matches + any `state_key` + + Returns: + deferred: A list of dicts corresponding to the event_ids given. + The dicts are mappings from (type, state_key) -> state_events + """ + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups, types) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + defer.returnValue({event: event_to_state[event] for event in event_ids}) + + @cached(num_args=2, lru=True, max_entries=10000) + def _get_state_group_for_event(self, room_id, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", + num_args=2) + def _get_state_group_for_events(self, room_id, event_ids): + """Returns mapping event_id -> state_group + """ def f(txn): - groups = set() - event_to_group = {} + results = {} for event_id in event_ids: - # TODO: Remove this loop. - group = self._simple_select_one_onecol_txn( + results[event_id] = self._simple_select_one_onecol_txn( txn, table="event_to_state_groups", - keyvalues={"event_id": event_id}, + keyvalues={ + "event_id": event_id, + }, retcol="state_group", allow_none=True, ) - if group: - event_to_group[event_id] = group - groups.add(group) - group_to_state_ids = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", + return results + + return self.runInteraction("_get_state_group_for_events", f) + + def _get_some_state_from_cache(self, group, types): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). + `missing_types` is the list of types that aren't in the cache for that + group. `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + + Args: + group: The state group to lookup + types (list): List of 2-tuples of the form (`type`, `state_key`), + where a `state_key` of `None` matches all state_keys for the + `type`. + """ + is_all, state_dict = self._state_group_cache.get(group) + + type_to_key = {} + missing_types = set() + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) + + if (typ, state_key) not in state_dict: + missing_types.add((typ, state_key)) + + sentinel = object() + + def include(typ, state_key): + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: + return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False + + got_all = not (missing_types or types is None) + + return { + k: v for k, v in state_dict.items() + if include(k[0], k[1]) + }, missing_types, got_all + + def _get_all_state_from_cache(self, group): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool + indicating if we successfully retrieved all requests state from the + cache, if False we need to query the DB for the missing state. + + Args: + group: The state group to lookup + """ + is_all, state_dict = self._state_group_cache.get(group) + return state_dict, is_all + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, types=None): + """Given list of groups returns dict of group -> list of state events + with matching types. `types` is a list of `(type, state_key)`, where + a `state_key` of None matches all state_keys. If `types` is None then + all events are returned. + """ + results = {} + missing_groups_and_types = [] + if types is not None: + for group in set(groups): + state_dict, missing_types, got_all = self._get_some_state_from_cache( + group, types + ) + results[group] = state_dict + + if not got_all: + missing_groups_and_types.append((group, missing_types)) + else: + for group in set(groups): + state_dict, got_all = self._get_all_state_from_cache( + group ) + results[group] = state_dict - group_to_state_ids[group] = state_ids + if not got_all: + missing_groups_and_types.append((group, None)) - return event_to_group, group_to_state_ids + if not missing_groups_and_types: + defer.returnValue({ + group: { + type_tuple: event + for type_tuple, event in state.items() + if event + } + for group, state in results.items() + }) - res = yield self.runInteraction( - "annotate_events_with_state_groups", - f, - ) + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence - event_to_group, group_to_state_ids = res + group_state_dict = yield self._get_state_groups_from_groups( + missing_groups_and_types + ) - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in group_to_state_ids.items() - ], - consumeErrors=True, + state_events = yield self._get_events( + [e_id for l in group_state_dict.values() for e_id in l], + get_prev_content=False ) - state_dict = { - group: { - (ev.type, ev.state_key): ev - for ev in state + state_events = {e.event_id: e for e in state_events} + + # Now we want to update the cache with all the things we fetched + # from the database. + for group, state_ids in group_state_dict.items(): + if types: + # We delibrately put key -> None mappings into the cache to + # cache absence of the key, on the assumption that if we've + # explicitly asked for some types then we will probably ask + # for them again. + state_dict = {key: None for key in types} + state_dict.update(results[group]) + results[group] = state_dict + else: + state_dict = results[group] + + for event_id in state_ids: + state_event = state_events[event_id] + state_dict[(state_event.type, state_event.state_key)] = state_event + + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + # Remove all the entries with None values. The None values were just + # used for bookkeeping in the cache. + for group, state_dict in results.items(): + results[group] = { + key: event for key, event in state_dict.items() if event } - for group, state in state_list - } - defer.returnValue([ - state_dict.get(event_to_group.get(event, None), None) - for event in event_ids - ]) + defer.returnValue(results) def _make_group_id(clock): diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index af45fc5619..d7fe423f5a 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -36,6 +36,7 @@ what sort order was used: from twisted.internet import defer from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logutils import log_function @@ -299,9 +300,8 @@ class StreamStore(SQLBaseStore): defer.returnValue((events, token)) - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token, - with_feedback=False, from_token=None): + @cachedInlineCallbacks(num_args=4) + def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): # TODO (erikj): Handle compressed feedback end_token = RoomStreamToken.parse_stream_token(end_token) diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 624da4a9dc..c8c7e6591a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from collections import namedtuple |