diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 18 | ||||
-rw-r--r-- | synapse/storage/_base.py | 171 | ||||
-rw-r--r-- | synapse/storage/directory.py | 4 | ||||
-rw-r--r-- | synapse/storage/end_to_end_keys.py | 125 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 132 | ||||
-rw-r--r-- | synapse/storage/events.py | 422 | ||||
-rw-r--r-- | synapse/storage/keys.py | 49 | ||||
-rw-r--r-- | synapse/storage/presence.py | 4 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 24 | ||||
-rw-r--r-- | synapse/storage/receipts.py | 347 | ||||
-rw-r--r-- | synapse/storage/registration.py | 2 | ||||
-rw-r--r-- | synapse/storage/room.py | 5 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 48 | ||||
-rw-r--r-- | synapse/storage/schema/delta/14/upgrade_appservice_db.py | 2 | ||||
-rw-r--r-- | synapse/storage/schema/delta/20/dummy.sql | 1 | ||||
-rw-r--r-- | synapse/storage/schema/delta/20/pushers.py | 76 | ||||
-rw-r--r-- | synapse/storage/schema/delta/21/end_to_end_keys.sql | 34 | ||||
-rw-r--r-- | synapse/storage/schema/delta/21/receipts.sql | 38 | ||||
-rw-r--r-- | synapse/storage/signatures.py | 28 | ||||
-rw-r--r-- | synapse/storage/state.py | 119 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 38 |
21 files changed, 1292 insertions, 395 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 75af44d787..c6ce65b4cc 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -37,6 +37,9 @@ from .rejections import RejectionsStore from .state import StateStore from .signatures import SignatureStore from .filtering import FilteringStore +from .end_to_end_keys import EndToEndKeyStore + +from .receipts import ReceiptsStore import fnmatch @@ -51,7 +54,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 19 +SCHEMA_VERSION = 21 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -74,6 +77,8 @@ class DataStore(RoomMemberStore, RoomStore, PushRuleStore, ApplicationServiceTransactionStore, EventsStore, + ReceiptsStore, + EndToEndKeyStore, ): def __init__(self, hs): @@ -94,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore, key = (user.to_string(), access_token, device_id, ip) try: - last_seen = self.client_ip_last_seen.get(*key) + last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None @@ -102,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore, if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: defer.returnValue(None) - self.client_ip_last_seen.prefill(*key + (now,)) + self.client_ip_last_seen.prefill(key, now) # It's safe not to lock here: a) no unique constraint, # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely @@ -348,7 +353,12 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module_name, absolute_path, python_file ) logger.debug("Running script %s", relative_path) - module.run_upgrade(cur) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass elif ext == ".sql": # A plain old .sql file, just read and execute it logger.debug("Applying schema %s", relative_path) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 39884c2afe..73eea157a4 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,6 +15,7 @@ import logging from synapse.api.errors import StoreError +from synapse.util.async import ObservableDeferred from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache @@ -27,6 +28,7 @@ from twisted.internet import defer from collections import namedtuple, OrderedDict import functools +import inspect import sys import time import threading @@ -55,9 +57,12 @@ cache_counter = metrics.register_cache( ) +_CacheSentinel = object() + + class Cache(object): - def __init__(self, name, max_entries=1000, keylen=1, lru=False): + def __init__(self, name, max_entries=1000, keylen=1, lru=True): if lru: self.cache = LruCache(max_size=max_entries) self.max_entries = None @@ -81,45 +86,44 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, *keyargs): - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) - - if keyargs in self.cache: + def get(self, key, default=_CacheSentinel): + val = self.cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: cache_counter.inc_hits(self.name) - return self.cache[keyargs] + return val cache_counter.inc_misses(self.name) - raise KeyError() - def update(self, sequence, *args): + if default is _CacheSentinel: + raise KeyError() + else: + return default + + def update(self, sequence, key, value): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) - - def prefill(self, *args): # because I can't *keyargs, value - keyargs = args[:-1] - value = args[-1] - - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + self.prefill(key, value) + def prefill(self, key, value): if self.max_entries is not None: while len(self.cache) >= self.max_entries: self.cache.popitem(last=False) - self.cache[keyargs] = value + self.cache[key] = value - def invalidate(self, *keyargs): + def invalidate(self, key): self.check_thread() - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + if not isinstance(key, tuple): + raise TypeError( + "The cache key must be a tuple not %r" % (type(key),) + ) + # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 - self.cache.pop(keyargs, None) + self.cache.pop(key, None) def invalidate_all(self): self.check_thread() @@ -127,9 +131,12 @@ class Cache(object): self.cache.clear() -def cached(max_entries=1000, num_args=1, lru=False): +class CacheDescriptor(object): """ A method decorator that applies a memoizing cache around the function. + This caches deferreds, rather than the results themselves. Deferreds that + fail are removed from the cache. + The function is presumed to take zero or more arguments, which are used in a tuple as the key for the cache. Hits are served directly from the cache; misses use the function body to generate the value. @@ -141,47 +148,108 @@ def cached(max_entries=1000, num_args=1, lru=False): which can be used to insert values into the cache specifically, without calling the calculation function. """ - def wrap(orig): - cache = Cache( - name=orig.__name__, - max_entries=max_entries, - keylen=num_args, - lru=lru, + def __init__(self, orig, max_entries=1000, num_args=1, lru=True, + inlineCallbacks=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.max_entries = max_entries + self.num_args = num_args + self.lru = lru + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + self.cache = Cache( + name=self.orig.__name__, + max_entries=self.max_entries, + keylen=self.num_args, + lru=self.lru, ) - @functools.wraps(orig) - @defer.inlineCallbacks - def wrapped(self, *keyargs): + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result = cache.get(*keyargs) + cached_result_d = self.cache.get(cache_key) + + observer = cached_result_d.observe() if DEBUG_CACHES: - actual_result = yield orig(self, *keyargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - orig.__name__, keyargs, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) + @defer.inlineCallbacks + def check_result(cached_result): + actual_result = yield self.function_to_call(obj, *args, **kwargs) + if actual_result != cached_result: + logger.error( + "Stale cache entry %s%r: cached: %r, actual %r", + self.orig.__name__, cache_key, + cached_result, actual_result, + ) + raise ValueError("Stale cache entry") + defer.returnValue(cached_result) + observer.addCallback(check_result) + + return observer except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = cache.sequence + sequence = self.cache.sequence + + ret = defer.maybeDeferred( + self.function_to_call, + obj, *args, **kwargs + ) + + def onErr(f): + self.cache.invalidate(cache_key) + return f + + ret.addErrback(onErr) + + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, cache_key, ret) - ret = yield orig(self, *keyargs) + return ret.observe() - cache.update(sequence, *keyargs + (ret,)) + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill - defer.returnValue(ret) + obj.__dict__[self.orig.__name__] = wrapped - wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all - wrapped.prefill = cache.prefill return wrapped - return wrap + +def cached(max_entries=1000, num_args=1, lru=True): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru + ) + + +def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru, + inlineCallbacks=True, + ) class LoggingTransaction(object): @@ -312,13 +380,14 @@ class SQLBaseStore(object): self.database_engine = hs.database_engine - self._stream_id_gen = StreamIdGenerator() + self._stream_id_gen = StreamIdGenerator("events", "stream_ordering") self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id") def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 2b2bdf8615..f3947bbe89 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore): }, desc="create_room_alias_association", ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) @defer.inlineCallbacks def delete_room_alias(self, room_alias): @@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore): room_alias, ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py new file mode 100644 index 0000000000..325740d7d0 --- /dev/null +++ b/synapse/storage/end_to_end_keys.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _base import SQLBaseStore + + +class EndToEndKeyStore(SQLBaseStore): + def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): + return self._simple_upsert( + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": json_bytes, + } + ) + + def get_e2e_device_keys(self, query_list): + """Fetch a list of device keys. + Args: + query_list(list): List of pairs of user_ids and device_ids. + Returns: + Dict mapping from user-id to dict mapping from device_id to + key json byte strings. + """ + def _get_e2e_device_keys(txn): + result = {} + for user_id, device_id in query_list: + user_result = result.setdefault(user_id, {}) + keyvalues = {"user_id": user_id} + if device_id: + keyvalues["device_id"] = device_id + rows = self._simple_select_list_txn( + txn, table="e2e_device_keys_json", + keyvalues=keyvalues, + retcols=["device_id", "key_json"] + ) + for row in rows: + user_result[row["device_id"]] = row["key_json"] + return result + return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) + + def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): + def _add_e2e_one_time_keys(txn): + for (algorithm, key_id, json_bytes) in key_list: + self._simple_upsert_txn( + txn, table="e2e_one_time_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": json_bytes, + } + ) + return self.runInteraction( + "add_e2e_one_time_keys", _add_e2e_one_time_keys + ) + + def count_e2e_one_time_keys(self, user_id, device_id): + """ Count the number of one time keys the server has for a device + Returns: + Dict mapping from algorithm to number of keys for that algorithm. + """ + def _count_e2e_one_time_keys(txn): + sql = ( + "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ?" + " GROUP BY algorithm" + ) + txn.execute(sql, (user_id, device_id)) + result = {} + for algorithm, key_count in txn.fetchall(): + result[algorithm] = key_count + return result + return self.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) + + def claim_e2e_one_time_keys(self, query_list): + """Take a list of one time keys out of the database""" + def _claim_e2e_one_time_keys(txn): + sql = ( + "SELECT key_id, key_json FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) + result = {} + delete = [] + for user_id, device_id, algorithm in query_list: + user_result = result.setdefault(user_id, {}) + device_result = user_result.setdefault(device_id, {}) + txn.execute(sql, (user_id, device_id, algorithm)) + for key_id, key_json in txn.fetchall(): + device_result[algorithm + ":" + key_id] = key_json + delete.append((user_id, device_id, algorithm, key_id)) + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " AND key_id = ?" + ) + for user_id, device_id, algorithm, key_id in delete: + txn.execute(sql, (user_id, device_id, algorithm, key_id)) + return result + return self.runInteraction( + "claim_e2e_one_time_keys", _claim_e2e_one_time_keys + ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 1ba073884b..910b6598a7 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -49,14 +49,22 @@ class EventFederationStore(SQLBaseStore): results = set() base_sql = ( - "SELECT auth_id FROM event_auth WHERE event_id = ?" + "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" ) front = set(event_ids) while front: new_front = set() - for f in front: - txn.execute(base_sql, (f,)) + front_list = list(front) + chunks = [ + front_list[x:x+100] + for x in xrange(0, len(front), 100) + ] + for chunk in chunks: + txn.execute( + base_sql % (",".join(["?"] * len(chunk)),), + chunk + ) new_front.update([r[0] for r in txn.fetchall()]) new_front -= results @@ -274,8 +282,7 @@ class EventFederationStore(SQLBaseStore): }, ) - def _handle_prev_events(self, txn, outlier, event_id, prev_events, - room_id): + def _handle_mult_prev_events(self, txn, events): """ For the given event, update the event edges table and forward and backward extremities tables. @@ -285,70 +292,77 @@ class EventFederationStore(SQLBaseStore): table="event_edges", values=[ { - "event_id": event_id, + "event_id": ev.event_id, "prev_event_id": e_id, - "room_id": room_id, + "room_id": ev.room_id, "is_state": False, } - for e_id, _ in prev_events + for ev in events + for e_id, _ in ev.prev_events ], ) - # Update the extremities table if this is not an outlier. - if not outlier: - for e_id, _ in prev_events: - # TODO (erikj): This could be done as a bulk insert - self._simple_delete_txn( - txn, - table="event_forward_extremities", - keyvalues={ - "event_id": e_id, - "room_id": room_id, - } - ) + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) - # We only insert as a forward extremity the new event if there are - # no other events that reference it as a prev event - query = ( - "SELECT 1 FROM event_edges WHERE prev_event_id = ?" - ) + for room_id, room_events in events_by_room.items(): + prevs = [ + e_id for ev in room_events for e_id, _ in ev.prev_events + if not ev.internal_metadata.is_outlier() + ] + if prevs: + txn.execute( + "DELETE FROM event_forward_extremities" + " WHERE room_id = ?" + " AND event_id in (%s)" % ( + ",".join(["?"] * len(prevs)), + ), + [room_id] + prevs, + ) - txn.execute(query, (event_id,)) + query = ( + "INSERT INTO event_forward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_edges WHERE prev_event_id = ?" + " )" + ) - if not txn.fetchone(): - query = ( - "INSERT INTO event_forward_extremities" - " (event_id, room_id)" - " VALUES (?, ?)" - ) + txn.executemany( + query, + [(ev.event_id, ev.room_id, ev.event_id) for ev in events] + ) - txn.execute(query, (event_id, room_id)) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" + ) - txn.executemany(query, [ - (e_id, room_id, e_id, room_id, e_id, room_id, False) - for e_id, _ in prev_events - ]) + txn.executemany(query, [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events for e_id, _ in ev.prev_events + if not ev.internal_metadata.is_outlier() + ]) - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.execute(query, (event_id, room_id)) + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [(ev.event_id, ev.room_id) for ev in events] + ) + for room_id in events_by_room: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, room_id + self.get_latest_event_ids_in_room.invalidate, (room_id,) ) def get_backfill_events(self, room_id, event_list, limit): @@ -400,10 +414,12 @@ class EventFederationStore(SQLBaseStore): keyvalues={ "event_id": event_id, }, - retcol="depth" + retcol="depth", + allow_none=True, ) - queue.put((-depth, event_id)) + if depth: + queue.put((-depth, event_id)) while not queue.empty() and len(event_results) < limit: try: @@ -489,4 +505,4 @@ class EventFederationStore(SQLBaseStore): query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 20a8d81794..5b64918024 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -23,9 +23,7 @@ from synapse.events.utils import prune_event from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logutils import log_function from synapse.api.constants import EventTypes -from synapse.crypto.event_signing import compute_event_reference_hash -from syutil.base64util import decode_base64 from syutil.jsonutil import encode_json from contextlib import contextmanager @@ -47,6 +45,48 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events class EventsStore(SQLBaseStore): @defer.inlineCallbacks + def persist_events(self, events_and_contexts, backfilled=False, + is_new_state=True): + if not events_and_contexts: + return + + if backfilled: + if not self.min_token_deferred.called: + yield self.min_token_deferred + start = self.min_token - 1 + self.min_token -= len(events_and_contexts) + 1 + stream_orderings = range(start, self.min_token, -1) + + @contextmanager + def stream_ordering_manager(): + yield stream_orderings + stream_ordering_manager = stream_ordering_manager() + else: + stream_ordering_manager = yield self._stream_id_gen.get_next_mult( + self, len(events_and_contexts) + ) + + with stream_ordering_manager as stream_orderings: + for (event, _), stream in zip(events_and_contexts, stream_orderings): + event.internal_metadata.stream_ordering = stream + + chunks = [ + events_and_contexts[x:x+100] + for x in xrange(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=chunk, + backfilled=backfilled, + is_new_state=is_new_state, + ) + + @defer.inlineCallbacks @log_function def persist_event(self, event, context, backfilled=False, is_new_state=True, current_state=None): @@ -67,13 +107,13 @@ class EventsStore(SQLBaseStore): try: with stream_ordering_manager as stream_ordering: + event.internal_metadata.stream_ordering = stream_ordering yield self.runInteraction( "persist_event", self._persist_event_txn, event=event, context=context, backfilled=backfilled, - stream_ordering=stream_ordering, is_new_state=is_new_state, current_state=current_state, ) @@ -116,19 +156,14 @@ class EventsStore(SQLBaseStore): @log_function def _persist_event_txn(self, txn, event, context, backfilled, - stream_ordering=None, is_new_state=True, - current_state=None): - - # Remove the any existing cache entries for the event_id - txn.call_after(self._invalidate_get_event_cache, event.event_id) - + is_new_state=True, current_state=None): # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases, event.room_id) self._simple_delete_txn( @@ -149,37 +184,78 @@ class EventsStore(SQLBaseStore): } ) - outlier = event.internal_metadata.is_outlier() + return self._persist_events_txn( + txn, + [(event, context)], + backfilled=backfilled, + is_new_state=is_new_state, + ) - if not outlier: - self._update_min_depth_for_room_txn( - txn, - event.room_id, - event.depth + @log_function + def _persist_events_txn(self, txn, events_and_contexts, backfilled, + is_new_state=True): + + # Remove the any existing cache entries for the event_ids + for event, _ in events_and_contexts: + txn.call_after(self._invalidate_get_event_cache, event.event_id) + + depth_updates = {} + for event, _ in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + depth_updates[event.room_id] = max( + event.depth, depth_updates.get(event.room_id, event.depth) ) - have_persisted = self._simple_select_one_txn( - txn, - table="events", - keyvalues={"event_id": event.event_id}, - retcols=["event_id", "outlier"], - allow_none=True, + for room_id, depth in depth_updates.items(): + self._update_min_depth_for_room_txn(txn, room_id, depth) + + txn.execute( + "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( + ",".join(["?"] * len(events_and_contexts)), + ), + [event.event_id for event, _ in events_and_contexts] ) + have_persisted = { + event_id: outlier + for event_id, outlier in txn.fetchall() + } + + event_map = {} + to_remove = set() + for event, context in events_and_contexts: + # Handle the case of the list including the same event multiple + # times. The tricky thing here is when they differ by whether + # they are an outlier. + if event.event_id in event_map: + other = event_map[event.event_id] + + if not other.internal_metadata.is_outlier(): + to_remove.add(event) + continue + elif not event.internal_metadata.is_outlier(): + to_remove.add(event) + continue + else: + to_remove.add(other) - metadata_json = encode_json( - event.internal_metadata.get_dict(), - using_frozen_dicts=USE_FROZEN_DICTS - ).decode("UTF-8") - - # If we have already persisted this event, we don't need to do any - # more processing. - # The processing above must be done on every call to persist event, - # since they might not have happened on previous calls. For example, - # if we are persisting an event that we had persisted as an outlier, - # but is no longer one. - if have_persisted: - if not outlier and have_persisted["outlier"]: - self._store_state_groups_txn(txn, event, context) + event_map[event.event_id] = event + + if event.event_id not in have_persisted: + continue + + to_remove.add(event) + + outlier_persisted = have_persisted[event.event_id] + if not event.internal_metadata.is_outlier() and outlier_persisted: + self._store_state_groups_txn( + txn, event, context, + ) + + metadata_json = encode_json( + event.internal_metadata.get_dict(), + using_frozen_dicts=USE_FROZEN_DICTS + ).decode("UTF-8") sql = ( "UPDATE event_json SET internal_metadata = ?" @@ -198,94 +274,91 @@ class EventsStore(SQLBaseStore): sql, (False, event.event_id,) ) - return - - if not outlier: - self._store_state_groups_txn(txn, event, context) - self._handle_prev_events( - txn, - outlier=outlier, - event_id=event.event_id, - prev_events=event.prev_events, - room_id=event.room_id, + events_and_contexts = filter( + lambda ec: ec[0] not in to_remove, + events_and_contexts ) - if event.type == EventTypes.Member: - self._store_room_member_txn(txn, event) - elif event.type == EventTypes.Name: - self._store_room_name_txn(txn, event) - elif event.type == EventTypes.Topic: - self._store_room_topic_txn(txn, event) - elif event.type == EventTypes.Redaction: - self._store_redaction(txn, event) - - event_dict = { - k: v - for k, v in event.get_dict().items() - if k not in [ - "redacted", - "redacted_because", - ] - } + if not events_and_contexts: + return - self._simple_insert_txn( + self._store_mult_state_groups_txn(txn, [ + (event, context) + for event, context in events_and_contexts + if not event.internal_metadata.is_outlier() + ]) + + self._handle_mult_prev_events( txn, - table="event_json", - values={ - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": metadata_json, - "json": encode_json( - event_dict, using_frozen_dicts=USE_FROZEN_DICTS - ).decode("UTF-8"), - }, + events=[event for event, _ in events_and_contexts], ) - content = encode_json( - event.content, using_frozen_dicts=USE_FROZEN_DICTS - ).decode("UTF-8") - - vals = { - "topological_ordering": event.depth, - "event_id": event.event_id, - "type": event.type, - "room_id": event.room_id, - "content": content, - "processed": True, - "outlier": outlier, - "depth": event.depth, - } + for event, _ in events_and_contexts: + if event.type == EventTypes.Name: + self._store_room_name_txn(txn, event) + elif event.type == EventTypes.Topic: + self._store_room_topic_txn(txn, event) + elif event.type == EventTypes.Redaction: + self._store_redaction(txn, event) - unrec = { - k: v - for k, v in event.get_dict().items() - if k not in vals.keys() and k not in [ - "redacted", - "redacted_because", - "signatures", - "hashes", - "prev_events", + self._store_room_members_txn( + txn, + [ + event + for event, _ in events_and_contexts + if event.type == EventTypes.Member ] - } + ) - vals["unrecognized_keys"] = encode_json( - unrec, using_frozen_dicts=USE_FROZEN_DICTS - ).decode("UTF-8") + def event_dict(event): + return { + k: v + for k, v in event.get_dict().items() + if k not in [ + "redacted", + "redacted_because", + ] + } - sql = ( - "INSERT INTO events" - " (stream_ordering, topological_ordering, event_id, type," - " room_id, content, processed, outlier, depth)" - " VALUES (?,?,?,?,?,?,?,?,?)" + self._simple_insert_many_txn( + txn, + table="event_json", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "internal_metadata": encode_json( + event.internal_metadata.get_dict(), + using_frozen_dicts=USE_FROZEN_DICTS + ).decode("UTF-8"), + "json": encode_json( + event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS + ).decode("UTF-8"), + } + for event, _ in events_and_contexts + ], ) - txn.execute( - sql, - ( - stream_ordering, event.depth, event.event_id, event.type, - event.room_id, content, True, outlier, event.depth - ) + self._simple_insert_many_txn( + txn, + table="events", + values=[ + { + "stream_ordering": event.internal_metadata.stream_ordering, + "topological_ordering": event.depth, + "depth": event.depth, + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "processed": True, + "outlier": event.internal_metadata.is_outlier(), + "content": encode_json( + event.content, using_frozen_dicts=USE_FROZEN_DICTS + ).decode("UTF-8"), + } + for event, _ in events_and_contexts + ], ) if context.rejected: @@ -293,20 +366,6 @@ class EventsStore(SQLBaseStore): txn, event.event_id, context.rejected ) - for hash_alg, hash_base64 in event.hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_event_content_hash_txn( - txn, event.event_id, hash_alg, hash_bytes, - ) - - for prev_event_id, prev_hashes in event.prev_events: - for alg, hash_base64 in prev_hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_prev_event_hash_txn( - txn, event.event_id, prev_event_id, alg, - hash_bytes - ) - self._simple_insert_many_txn( txn, table="event_auth", @@ -316,16 +375,22 @@ class EventsStore(SQLBaseStore): "room_id": event.room_id, "auth_id": auth_id, } + for event, _ in events_and_contexts for auth_id, _ in event.auth_events ], ) - (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) - self._store_event_reference_hash_txn( - txn, event.event_id, ref_alg, ref_hash_bytes + self._store_event_reference_hashes_txn( + txn, [event for event, _ in events_and_contexts] ) - if event.is_state(): + state_events_and_contexts = filter( + lambda i: i[0].is_state(), + events_and_contexts, + ) + + state_values = [] + for event, context in state_events_and_contexts: vals = { "event_id": event.event_id, "room_id": event.room_id, @@ -337,51 +402,55 @@ class EventsStore(SQLBaseStore): if hasattr(event, "replaces_state"): vals["prev_state"] = event.replaces_state - self._simple_insert_txn( - txn, - "state_events", - vals, - ) + state_values.append(vals) - self._simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": event.event_id, - "prev_event_id": e_id, - "room_id": event.room_id, - "is_state": True, - } - for e_id, h in event.prev_state - ], - ) + self._simple_insert_many_txn( + txn, + table="state_events", + values=state_values, + ) - if is_new_state and not context.rejected: - txn.call_after( - self.get_current_state_for_key.invalidate, - event.room_id, event.type, event.state_key - ) + self._simple_insert_many_txn( + txn, + table="event_edges", + values=[ + { + "event_id": event.event_id, + "prev_event_id": prev_id, + "room_id": event.room_id, + "is_state": True, + } + for event, _ in state_events_and_contexts + for prev_id, _ in event.prev_state + ], + ) - if (event.type == EventTypes.Name - or event.type == EventTypes.Aliases): + if is_new_state: + for event, _ in state_events_and_contexts: + if not context.rejected: txn.call_after( - self.get_room_name_and_aliases.invalidate, - event.room_id + self.get_current_state_for_key.invalidate, + (event.room_id, event.type, event.state_key,) ) - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) + if event.type in [EventTypes.Name, EventTypes.Aliases]: + txn.call_after( + self.get_room_name_and_aliases.invalidate, + (event.room_id,) + ) + + self._simple_upsert_txn( + txn, + "current_state_events", + keyvalues={ + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + values={ + "event_id": event.event_id, + } + ) return @@ -498,8 +567,9 @@ class EventsStore(SQLBaseStore): def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True): for get_prev_content in (False, True): - self._get_event_cache.invalidate(event_id, check_redacted, - get_prev_content) + self._get_event_cache.invalidate( + (event_id, check_redacted, get_prev_content) + ) def _get_event_txn(self, txn, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False): @@ -520,7 +590,7 @@ class EventsStore(SQLBaseStore): for event_id in events: try: ret = self._get_event_cache.get( - event_id, check_redacted, get_prev_content + (event_id, check_redacted, get_prev_content,) ) if allow_rejected or not ret.rejected_reason: @@ -736,7 +806,8 @@ class EventsStore(SQLBaseStore): because = yield self.get_event( redaction_id, - check_redacted=False + check_redacted=False, + allow_none=True, ) if because: @@ -746,12 +817,13 @@ class EventsStore(SQLBaseStore): prev = yield self.get_event( ev.unsigned["replaces_state"], get_prev_content=False, + allow_none=True, ) if prev: ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) defer.returnValue(ev) @@ -808,7 +880,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) return ev diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 5bdf497b93..49b8e37cfd 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from _base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer @@ -71,6 +71,24 @@ class KeyStore(SQLBaseStore): desc="store_server_certificate", ) + @cachedInlineCallbacks() + def get_all_server_verify_keys(self, server_name): + rows = yield self._simple_select_list( + table="server_signature_keys", + keyvalues={ + "server_name": server_name, + }, + retcols=["key_id", "verify_key"], + desc="get_all_server_verify_keys", + ) + + defer.returnValue({ + row["key_id"]: decode_verify_key_bytes( + row["key_id"], str(row["verify_key"]) + ) + for row in rows + }) + @defer.inlineCallbacks def get_server_verify_keys(self, server_name, key_ids): """Retrieve the NACL verification key for a given server for the given @@ -81,24 +99,14 @@ class KeyStore(SQLBaseStore): Returns: (list of VerifyKey): The verification keys. """ - sql = ( - "SELECT key_id, verify_key FROM server_signature_keys" - " WHERE server_name = ?" - " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")" - ) - - rows = yield self._execute_and_decode( - "get_server_verify_keys", sql, server_name, *key_ids - ) - - keys = [] - for row in rows: - key_id = row["key_id"] - key_bytes = row["verify_key"] - key = decode_verify_key_bytes(key_id, str(key_bytes)) - keys.append(key) - defer.returnValue(keys) + keys = yield self.get_all_server_verify_keys(server_name) + defer.returnValue({ + k: keys[k] + for k in key_ids + if k in keys and keys[k] + }) + @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, verify_key): """Stores a NACL verification key for the given server. @@ -109,7 +117,7 @@ class KeyStore(SQLBaseStore): ts_now_ms (int): The time now in milliseconds verification_key (VerifyKey): The NACL verify key. """ - return self._simple_upsert( + yield self._simple_upsert( table="server_signature_keys", keyvalues={ "server_name": server_name, @@ -123,6 +131,8 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) + self.get_all_server_verify_keys.invalidate((server_name,)) + def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): """Stores the JSON bytes for a set of keys from a server @@ -152,6 +162,7 @@ class KeyStore(SQLBaseStore): "ts_valid_until_ms": ts_expires_ms, "key_json": buffer(key_json_bytes), }, + desc="store_server_keys_json", ) def get_server_keys_json(self, server_keys): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index fefcf6bce0..576cf670cc 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore): updatevalues={"accepted": True}, desc="set_presence_list_accepted", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore): "observed_user_id": observed_userid}, desc="del_presence_list", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 4cac118d17..9b88ca7b39 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer import logging @@ -23,8 +23,7 @@ logger = logging.getLogger(__name__) class PushRuleStore(SQLBaseStore): - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_push_rules_for_user(self, user_name): rows = yield self._simple_select_list( table=PushRuleTable.table_name, @@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rows) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_push_rules_enabled_for_user(self, user_name): results = yield self._simple_select_list( table=PushRuleEnableTable.table_name, @@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore): new_rule['priority'] = new_prio txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore): desc="delete_push_rule", ) - self.get_push_rules_for_user.invalidate(user_name) - self.get_push_rules_enabled_for_user.invalidate(user_name) + self.get_push_rules_for_user.invalidate((user_name,)) + self.get_push_rules_enabled_for_user.invalidate((user_name,)) @defer.inlineCallbacks def set_push_rule_enabled(self, user_name, rule_id, enabled): @@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py new file mode 100644 index 0000000000..b79d6683ca --- /dev/null +++ b/synapse/storage/receipts.py @@ -0,0 +1,347 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore, cachedInlineCallbacks + +from twisted.internet import defer + +from synapse.util import unwrapFirstError + +from blist import sorteddict +import logging +import ujson as json + + +logger = logging.getLogger(__name__) + + +class ReceiptsStore(SQLBaseStore): + def __init__(self, hs): + super(ReceiptsStore, self).__init__(hs) + + self._receipts_stream_cache = _RoomStreamChangeCache() + + @defer.inlineCallbacks + def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + """Get receipts for multiple rooms for sending to clients. + + Args: + room_ids (list): List of room_ids. + to_key (int): Max stream id to fetch receipts upto. + from_key (int): Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + list: A list of receipts. + """ + room_ids = set(room_ids) + + if from_key: + room_ids = yield self._receipts_stream_cache.get_rooms_changed( + self, room_ids, from_key + ) + + results = yield defer.gatherResults( + [ + self.get_linearized_receipts_for_room( + room_id, to_key, from_key=from_key + ) + for room_id in room_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + defer.returnValue([ev for res in results for ev in res]) + + @defer.inlineCallbacks + def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + """Get receipts for a single room for sending to clients. + + Args: + room_ids (str): The room id. + to_key (int): Max stream id to fetch receipts upto. + from_key (int): Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + list: A list of receipts. + """ + def f(txn): + if from_key: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id > ? AND stream_id <= ?" + ) + + txn.execute( + sql, + (room_id, from_key, to_key) + ) + else: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id <= ?" + ) + + txn.execute( + sql, + (room_id, to_key) + ) + + rows = self.cursor_to_dict(txn) + + return rows + + rows = yield self.runInteraction( + "get_linearized_receipts_for_room", f + ) + + if not rows: + defer.returnValue([]) + + content = {} + for row in rows: + content.setdefault( + row["event_id"], {} + ).setdefault( + row["receipt_type"], {} + )[row["user_id"]] = json.loads(row["data"]) + + defer.returnValue([{ + "type": "m.receipt", + "room_id": room_id, + "content": content, + }]) + + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_max_token(self) + + @cachedInlineCallbacks() + def get_graph_receipts_for_room(self, room_id): + """Get receipts for sending to remote servers. + """ + rows = yield self._simple_select_list( + table="receipts_graph", + keyvalues={"room_id": room_id}, + retcols=["receipt_type", "user_id", "event_id"], + desc="get_linearized_receipts_for_room", + ) + + result = {} + for row in rows: + result.setdefault( + row["user_id"], {} + ).setdefault( + row["receipt_type"], [] + ).append(row["event_id"]) + + defer.returnValue(result) + + def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, + user_id, event_id, data, stream_id): + + # We don't want to clobber receipts for more recent events, so we + # have to compare orderings of existing receipts + sql = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " INNER JOIN receipts_linearized as r USING (event_id, room_id)" + " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + ) + + txn.execute(sql, (room_id, receipt_type, user_id)) + results = txn.fetchall() + + if results: + res = self._simple_select_one_txn( + txn, + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + ) + topological_ordering = int(res["topological_ordering"]) + stream_ordering = int(res["stream_ordering"]) + + for to, so, _ in results: + if int(to) > topological_ordering: + return False + elif int(to) == topological_ordering and int(so) >= stream_ordering: + return False + + self._simple_delete_txn( + txn, + table="receipts_linearized", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + ) + + self._simple_insert_txn( + txn, + table="receipts_linearized", + values={ + "stream_id": stream_id, + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_id": event_id, + "data": json.dumps(data), + } + ) + + return True + + @defer.inlineCallbacks + def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): + """Insert a receipt, either from local client or remote server. + + Automatically does conversion between linearized and graph + representations. + """ + if not event_ids: + return + + if len(event_ids) == 1: + linearized_event_id = event_ids[0] + else: + # we need to points in graph -> linearized form. + # TODO: Make this better. + def graph_to_linear(txn): + query = ( + "SELECT event_id WHERE room_id = ? AND stream_ordering IN (" + " SELECT max(stream_ordering) WHERE event_id IN (%s)" + ")" + ) % (",".join(["?"] * len(event_ids))) + + txn.execute(query, [room_id] + event_ids) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) + + linearized_event_id = yield self.runInteraction( + "insert_receipt_conv", graph_to_linear + ) + + stream_id_manager = yield self._receipts_id_gen.get_next(self) + with stream_id_manager as stream_id: + yield self._receipts_stream_cache.room_has_changed( + self, room_id, stream_id + ) + have_persisted = yield self.runInteraction( + "insert_linearized_receipt", + self.insert_linearized_receipt_txn, + room_id, receipt_type, user_id, linearized_event_id, + data, + stream_id=stream_id, + ) + + if not have_persisted: + defer.returnValue(None) + + yield self.insert_graph_receipt( + room_id, receipt_type, user_id, event_ids, data + ) + + max_persisted_id = yield self._stream_id_gen.get_max_token(self) + defer.returnValue((stream_id, max_persisted_id)) + + def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, + data): + return self.runInteraction( + "insert_graph_receipt", + self.insert_graph_receipt_txn, + room_id, receipt_type, user_id, event_ids, data + ) + + def insert_graph_receipt_txn(self, txn, room_id, receipt_type, + user_id, event_ids, data): + self._simple_delete_txn( + txn, + table="receipts_graph", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + ) + self._simple_insert_txn( + txn, + table="receipts_graph", + values={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_ids": json.dumps(event_ids), + "data": json.dumps(data), + } + ) + + +class _RoomStreamChangeCache(object): + """Keeps track of the stream_id of the latest change in rooms. + + Given a list of rooms and stream key, it will give a subset of rooms that + may have changed since that key. If the key is too old then the cache + will simply return all rooms. + """ + def __init__(self, size_of_cache=10000): + self._size_of_cache = size_of_cache + self._room_to_key = {} + self._cache = sorteddict() + self._earliest_key = None + + @defer.inlineCallbacks + def get_rooms_changed(self, store, room_ids, key): + """Returns subset of room ids that have had new receipts since the + given key. If the key is too old it will just return the given list. + """ + if key > (yield self._get_earliest_key(store)): + keys = self._cache.keys() + i = keys.bisect_right(key) + + result = set( + self._cache[k] for k in keys[i:] + ).intersection(room_ids) + else: + result = room_ids + + defer.returnValue(result) + + @defer.inlineCallbacks + def room_has_changed(self, store, room_id, key): + """Informs the cache that the room has been changed at the given key. + """ + if key > (yield self._get_earliest_key(store)): + old_key = self._room_to_key.get(room_id, None) + if old_key: + key = max(key, old_key) + self._cache.pop(old_key, None) + self._cache[key] = room_id + + while len(self._cache) > self._size_of_cache: + k, r = self._cache.popitem() + self._earliest_key = max(k, self._earliest_key) + self._room_to_key.pop(r, None) + + @defer.inlineCallbacks + def _get_earliest_key(self, store): + if self._earliest_key is None: + self._earliest_key = yield store.get_max_receipt_stream_id() + self._earliest_key = int(self._earliest_key) + + defer.returnValue(self._earliest_key) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 90e2606be2..4eaa088b36 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore): user_id ) for r in rows: - self.get_user_by_token.invalidate(r) + self.get_user_by_token.invalidate((r,)) @cached() def get_user_by_token(self, token): diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 4612a8aa83..dd5bc2c8fb 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cachedInlineCallbacks import collections import logging @@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore): } ) - @cached() - @defer.inlineCallbacks + @cachedInlineCallbacks() def get_room_name_and_aliases(self, room_id): def f(txn): sql = ( diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index d36a6c18a8..9f14f38f24 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -35,38 +35,28 @@ RoomsForUser = namedtuple( class RoomMemberStore(SQLBaseStore): - def _store_room_member_txn(self, txn, event): + def _store_room_members_txn(self, txn, events): """Store a room member in the database. """ - try: - target_user_id = event.state_key - except: - logger.exception( - "Failed to parse target_user_id=%s", target_user_id - ) - raise - - logger.debug( - "_store_room_member_txn: target_user_id=%s, membership=%s", - target_user_id, - event.membership, - ) - - self._simple_insert_txn( + self._simple_insert_many_txn( txn, - "room_memberships", - { - "event_id": event.event_id, - "user_id": target_user_id, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - } + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + } + for event in events + ] ) - txn.call_after(self.get_rooms_for_user.invalidate, target_user_id) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) + for event in events: + txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. @@ -88,7 +78,7 @@ class RoomMemberStore(SQLBaseStore): lambda events: events[0] if events else None ) - @cached() + @cached(max_entries=5000) def get_users_in_room(self, room_id): def f(txn): @@ -164,7 +154,7 @@ class RoomMemberStore(SQLBaseStore): RoomsForUser(**r) for r in self.cursor_to_dict(txn) ] - @cached() + @cached(max_entries=5000) def get_joined_hosts_for_room(self, room_id): return self.runInteraction( "get_joined_hosts_for_room", diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 9f3a4dd4c5..61232f9757 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -18,7 +18,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur): +def run_upgrade(cur, *args, **kwargs): cur.execute("SELECT id, regex FROM application_services_regex") for row in cur.fetchall(): try: diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/schema/delta/20/dummy.sql new file mode 100644 index 0000000000..e0ac49d1ec --- /dev/null +++ b/synapse/storage/schema/delta/20/dummy.sql @@ -0,0 +1 @@ +SELECT 1; diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py new file mode 100644 index 0000000000..543e57bbe2 --- /dev/null +++ b/synapse/storage/schema/delta/20/pushers.py @@ -0,0 +1,76 @@ +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Main purpose of this upgrade is to change the unique key on the +pushers table again (it was missed when the v16 full schema was +made) but this also changes the pushkey and data columns to text. +When selecting a bytea column into a text column, postgres inserts +the hex encoded data, and there's no portable way of getting the +UTF-8 bytes, so we have to do it in Python. +""" + +import logging + +logger = logging.getLogger(__name__) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + logger.info("Porting pushers table...") + cur.execute(""" + CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag VARCHAR(32) NOT NULL, + kind VARCHAR(8) NOT NULL, + app_id VARCHAR(64) NOT NULL, + app_display_name VARCHAR(64) NOT NULL, + device_display_name VARCHAR(128) NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang VARCHAR(8), + data TEXT, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) + ) + """) + cur.execute("""SELECT + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + FROM pushers + """) + count = 0 + for row in cur.fetchall(): + row = list(row) + row[8] = bytes(row[8]).decode("utf-8") + row[11] = bytes(row[11]).decode("utf-8") + cur.execute(database_engine.convert_param_style(""" + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))), + row + ) + count += 1 + cur.execute("DROP TABLE pushers") + cur.execute("ALTER TABLE pushers2 RENAME TO pushers") + logger.info("Moved %d pushers to new table", count) diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/schema/delta/21/end_to_end_keys.sql new file mode 100644 index 0000000000..8b4a380d11 --- /dev/null +++ b/synapse/storage/schema/delta/21/end_to_end_keys.sql @@ -0,0 +1,34 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS e2e_device_keys_json ( + user_id TEXT NOT NULL, -- The user these keys are for. + device_id TEXT NOT NULL, -- Which of the user's devices these keys are for. + ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded. + key_json TEXT NOT NULL, -- The keys for the device as a JSON blob. + CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id) +); + + +CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json ( + user_id TEXT NOT NULL, -- The user this one-time key is for. + device_id TEXT NOT NULL, -- The device this one-time key is for. + algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. + key_json TEXT NOT NULL, -- The key as a JSON blob. + CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id) +); diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/schema/delta/21/receipts.sql new file mode 100644 index 0000000000..2f64d609fc --- /dev/null +++ b/synapse/storage/schema/delta/21/receipts.sql @@ -0,0 +1,38 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS receipts_graph( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) +); + +CREATE TABLE IF NOT EXISTS receipts_linearized ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) +); + +CREATE INDEX receipts_linearized_id ON receipts_linearized( + stream_id +); diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index f051828630..4f15e534b4 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -18,6 +18,7 @@ from twisted.internet import defer from _base import SQLBaseStore from syutil.base64util import encode_base64 +from synapse.crypto.event_signing import compute_event_reference_hash class SignatureStore(SQLBaseStore): @@ -101,23 +102,26 @@ class SignatureStore(SQLBaseStore): txn.execute(query, (event_id, )) return {k: v for k, v in txn.fetchall()} - def _store_event_reference_hash_txn(self, txn, event_id, algorithm, - hash_bytes): + def _store_event_reference_hashes_txn(self, txn, events): """Store a hash for a PDU Args: txn (cursor): - event_id (str): Id for the Event. - algorithm (str): Hashing algorithm. - hash_bytes (bytes): Hash function output bytes. + events (list): list of Events. """ - self._simple_insert_txn( + + vals = [] + for event in events: + ref_alg, ref_hash_bytes = compute_event_reference_hash(event) + vals.append({ + "event_id": event.event_id, + "algorithm": ref_alg, + "hash": buffer(ref_hash_bytes), + }) + + self._simple_insert_many_txn( txn, - "event_reference_hashes", - { - "event_id": event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, + table="event_reference_hashes", + values=vals, ) def _get_event_signatures_txn(self, txn, event_id): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b24de34f23..7ce51b9bdc 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer @@ -81,31 +81,41 @@ class StateStore(SQLBaseStore): f, ) - @defer.inlineCallbacks - def c(vals): - vals[:] = yield self._get_events(vals, get_prev_content=False) - - yield defer.gatherResults( + state_list = yield defer.gatherResults( [ - c(vals) - for vals in states.values() + self._fetch_events_for_group(group, vals) + for group, vals in states.items() ], consumeErrors=True, ) - defer.returnValue(states) + defer.returnValue(dict(state_list)) + + def _fetch_events_for_group(self, key, events): + return self._get_events( + events, get_prev_content=False + ).addCallback( + lambda evs: (key, evs) + ) def _store_state_groups_txn(self, txn, event, context): - if context.current_state is None: - return + return self._store_mult_state_groups_txn(txn, [(event, context)]) - state_events = dict(context.current_state) + def _store_mult_state_groups_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if context.current_state is None: + continue - if event.is_state(): - state_events[(event.type, event.state_key)] = event + if context.state_group is not None: + state_groups[event.event_id] = context.state_group + continue + + state_events = dict(context.current_state) + + if event.is_state(): + state_events[(event.type, event.state_key)] = event - state_group = context.state_group - if not state_group: state_group = self._state_groups_id_gen.get_next_txn(txn) self._simple_insert_txn( txn, @@ -131,14 +141,19 @@ class StateStore(SQLBaseStore): for state in state_events.values() ], ) + state_groups[event.event_id] = state_group - self._simple_insert_txn( + self._simple_insert_many_txn( txn, table="event_to_state_groups", - values={ - "state_group": state_group, - "event_id": event.event_id, - }, + values=[ + { + "state_group": state_groups[event.event_id], + "event_id": event.event_id, + } + for event, context in events_and_contexts + if context.current_state is not None + ], ) @defer.inlineCallbacks @@ -173,8 +188,7 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cached(num_args=3) - @defer.inlineCallbacks + @cachedInlineCallbacks(num_args=3) def get_current_state_for_key(self, room_id, event_type, state_key): def f(txn): sql = ( @@ -190,6 +204,65 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids): + def f(txn): + groups = set() + event_to_group = {} + for event_id in event_ids: + # TODO: Remove this loop. + group = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + event_to_group[event_id] = group + groups.add(group) + + group_to_state_ids = {} + for group in groups: + state_ids = self._simple_select_onecol_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + + group_to_state_ids[group] = state_ids + + return event_to_group, group_to_state_ids + + res = yield self.runInteraction( + "annotate_events_with_state_groups", + f, + ) + + event_to_group, group_to_state_ids = res + + state_list = yield defer.gatherResults( + [ + self._fetch_events_for_group(group, vals) + for group, vals in group_to_state_ids.items() + ], + consumeErrors=True, + ) + + state_dict = { + group: { + (ev.type, ev.state_key): ev + for ev in state + } + for group, state in state_list + } + + defer.returnValue([ + state_dict.get(event_to_group.get(event, None), None) + for event in event_ids + ]) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 89d1643f10..e956df62c7 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -72,7 +72,10 @@ class StreamIdGenerator(object): with stream_id_gen.get_next_txn(txn) as stream_id: # ... persist event ... """ - def __init__(self): + def __init__(self, table, column): + self.table = table + self.column = column + self._lock = threading.Lock() self._current_max = None @@ -108,6 +111,37 @@ class StreamIdGenerator(object): defer.returnValue(manager()) @defer.inlineCallbacks + def get_next_mult(self, store, n): + """ + Usage: + with yield stream_id_gen.get_next(store, n) as stream_ids: + # ... persist events ... + """ + if not self._current_max: + yield store.runInteraction( + "_compute_current_max", + self._get_or_compute_current_max, + ) + + with self._lock: + next_ids = range(self._current_max + 1, self._current_max + n + 1) + self._current_max += n + + for next_id in next_ids: + self._unfinished_ids.append(next_id) + + @contextlib.contextmanager + def manager(): + try: + yield next_ids + finally: + with self._lock: + for next_id in next_ids: + self._unfinished_ids.remove(next_id) + + defer.returnValue(manager()) + + @defer.inlineCallbacks def get_max_token(self, store): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. @@ -126,7 +160,7 @@ class StreamIdGenerator(object): def _get_or_compute_current_max(self, txn): with self._lock: - txn.execute("SELECT MAX(stream_ordering) FROM events") + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) rows = txn.fetchall() val, = rows[0] |