From 6823fe52410db3b95df720b7955ad7b617dc7dee Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Jan 2017 18:25:13 +0000 Subject: Linearize updates to membership via PUT /state/ --- tests/rest/client/v1/test_rooms.py | 4 ++-- tests/rest/client/v1/utils.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 4fe99ebc0b..6bce352c5f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase): # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 404s because room doesn't exist on any server for usr in [self.user_id, self.rmcreator_id]: - yield self.join(room=room, user=usr, expect_code=403) - yield self.leave(room=room, user=usr, expect_code=403) + yield self.join(room=room, user=usr, expect_code=404) + yield self.leave(room=room, user=usr, expect_code=404) @defer.inlineCallbacks def test_membership_private_room_perms(self): diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 17524b2e23..3bb1dd003a 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -87,7 +87,10 @@ class RestTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", path, json.dumps(data) ) - self.assertEquals(expect_code, code, msg=str(response)) + self.assertEquals( + expect_code, code, + msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response) + ) self.auth_user_id = temp_id -- cgit 1.4.1 From 2fae34bd2ce152b8544d5a90fe3b35281c5fffbc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 17:46:17 +0000 Subject: Optionally measure size of cache by sum of length of values --- synapse/storage/roommember.py | 3 ++- synapse/storage/state.py | 2 +- synapse/util/caches/descriptors.py | 25 ++++++++++++++++++++----- synapse/util/caches/lrucache.py | 32 ++++++++++++++++++-------------- tests/util/test_lrucache.py | 25 +++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 21 deletions(-) (limited to 'tests') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5d18037c7c..e63aab6ccf 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore): room_id, state_group, state_ids, ) - @cachedInlineCallbacks(num_args=2, cache_context=True) + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=2000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7f466c40ac..c480743f89 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=1000) + @cached(num_args=2, max_entries=1000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8dba61d49f..d082c26b1f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -42,6 +42,13 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = 
float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) +def deferred_size(deferred): + if deferred.called: + return len(deferred.result) + else: + return 1 + + class Cache(object): __slots__ = ( "cache", @@ -53,10 +60,11 @@ class Cache(object): "metrics", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False): + def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type + max_size=max_entries, keylen=keylen, cache_type=cache_type, + size_callback=deferred_size if iterable else None, ) self.name = name @@ -155,7 +163,7 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, tree=False, - inlineCallbacks=False, cache_context=False): + inlineCallbacks=False, cache_context=False, iterable=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -169,6 +177,8 @@ class CacheDescriptor(object): self.num_args = num_args self.tree = tree + self.iterable = iterable + all_args = inspect.getargspec(orig) self.arg_names = all_args.args[1:num_args + 1] @@ -203,6 +213,7 @@ class CacheDescriptor(object): max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, + iterable=self.iterable, ) @functools.wraps(self.orig) @@ -421,17 +432,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, cache_context=cache_context, + iterable=iterable, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -439,6 +453,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex tree=tree, inlineCallbacks=True, cache_context=cache_context, + iterable=iterable, ) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 9c4c679175..00ddf38290 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,7 +49,7 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): cache = cache_type() self.cache = cache # Used for introspection. 
list_root = _Node(None, None, None, None) @@ -58,6 +58,18 @@ class LruCache(object): lock = threading.Lock() + def cache_len(): + if size_callback is not None: + return sum(size_callback(node.value) for node in cache.itervalues()) + else: + return len(cache) + + def evict(): + while cache_len() > max_size: + todelete = list_root.prev_node + delete_node(todelete) + cache.pop(todelete.key, None) + def synchronized(f): @wraps(f) def inner(*args, **kwargs): @@ -127,22 +139,18 @@ class LruCache(object): else: callbacks = set() add_node(key, value, callbacks) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + + evict() @synchronized def cache_set_default(key, value): node = cache.get(key, None) if node is not None: + evict() # As the new node may be bigger than the old node. return node.value else: add_node(key, value) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + evict() return value @synchronized @@ -175,10 +183,6 @@ class LruCache(object): cb() cache.clear() - @synchronized - def cache_len(): - return len(cache) - @synchronized def cache_contains(key): return key in cache @@ -190,7 +194,7 @@ class LruCache(object): self.pop = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi - self.len = cache_len + self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 1eba5b535e..d888a64d0a 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -232,3 +232,28 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 1) + + +class LruCacheSizedTestCase(unittest.TestCase): + + def test_evict(self): + cache = LruCache(5, size_callback=len) + cache["key1"] = [0] + cache["key2"] = [1, 2] + cache["key3"] = [3] + cache["key4"] = [4] + + self.assertEquals(cache["key1"], [0]) + self.assertEquals(cache["key2"], [1, 2]) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(len(cache), 5) + + cache["key5"] = [5, 6] + + self.assertEquals(len(cache), 4) + self.assertEquals(cache.get("key1"), None) + self.assertEquals(cache.get("key2"), None) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(cache["key5"], [5, 6]) -- cgit 1.4.1 From f2f179dce26f42ea0e691d17c60b297c63898923 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Jan 2017 15:33:34 +0000 Subject: Add ExpiringCache tests --- tests/util/test_expiring_cache.py | 84 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/util/test_expiring_cache.py (limited to 'tests') diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py new file mode 100644 index 0000000000..31d24adb8b --- /dev/null +++ b/tests/util/test_expiring_cache.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .. import unittest
+
+from synapse.util.caches.expiringcache import ExpiringCache
+
+from tests.utils import MockClock
+
+
+class ExpiringCacheTestCase(unittest.TestCase):
+
+    def test_get_set(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=1)
+
+        cache["key"] = "value"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache["key"], "value")
+
+    def test_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=2)
+
+        cache["key"] = "value"
+        cache["key2"] = "value2"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache.get("key2"), "value2")
+
+        cache["key3"] = "value3"
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), "value2")
+        self.assertEquals(cache.get("key3"), "value3")
+
+    def test_iterable_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=5, iterable=True)
+
+        cache["key"] = [1]
+        cache["key2"] = [2, 3]
+        cache["key3"] = [4, 5]
+
+        self.assertEquals(cache.get("key"), [1])
+        self.assertEquals(cache.get("key2"), [2, 3])
+        self.assertEquals(cache.get("key3"), [4, 5])
+
+        cache["key4"] = [6, 7]
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)
+        self.assertEquals(cache.get("key3"), [4, 5])
+        self.assertEquals(cache.get("key4"), [6, 7])
+
+    def test_time_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, expiry_ms=1000)
+        cache.start()
+
+        cache["key"] = 1
+        clock.advance_time(0.5)
+        cache["key2"] = 2
+
+        self.assertEquals(cache.get("key"), 1)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(0.9)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(1)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)
--
cgit 1.4.1


From f85b6ca494ae587731d99196020cc74d7eca012a Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 17 Jan 2017 11:18:13 +0000
Subject: Speed up cache size calculation

Instead of calculating the size of the cache repeatedly, which can take
a long time now that it can use a callback, cache the size and update
it on insertion and deletion.

This requires changing the cache descriptors to have two caches, one
for pending deferreds and the other for the actual values. There's no
reason to evict from the pending deferreds as they won't take up any
more memory.
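
To illustrate the idea, here is a minimal, hypothetical sketch (the
class and method names are invented for this example; the actual change
is made inside LruCache in the diff below, via its size_callback and a
cached length counter):

    from collections import OrderedDict

    class SizedCache(object):
        """Keeps a running total of entry sizes so len() is O(1)."""

        def __init__(self, max_size, size_callback=len):
            self._data = OrderedDict()
            self._size = 0  # cached total, updated incrementally
            self._max_size = max_size
            self._size_callback = size_callback

        def set(self, key, value):
            # Remove any old entry for this key and subtract its size.
            old = self._data.pop(key, None)
            if old is not None:
                self._size -= self._size_callback(old)
            self._data[key] = value
            self._size += self._size_callback(value)
            # Evict oldest entries until we fit. (The real LruCache
            # evicts by recency of access; FIFO keeps this sketch short.)
            while self._size > self._max_size and self._data:
                _, evicted = self._data.popitem(last=False)
                self._size -= self._size_callback(evicted)

        def __len__(self):
            # No per-entry recomputation: just return the running total.
            return self._size

The real change below applies the same bookkeeping to LruCache (a
cached_cache_len counter adjusted in add_node, delete_node and
cache_set) rather than summing size_callback over every value on each
call to len().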
--- synapse/util/caches/descriptors.py | 97 +++++++++++++++++++++++++-------- synapse/util/caches/dictionary_cache.py | 6 +- synapse/util/caches/expiringcache.py | 15 ++++- synapse/util/caches/lrucache.py | 42 ++++++++------ synapse/util/caches/treecache.py | 14 ++++- tests/storage/test__base.py | 6 +- tests/util/test_lrucache.py | 30 +++++----- 7 files changed, 148 insertions(+), 62 deletions(-) (limited to 'tests') diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index d082c26b1f..b3b2d6092d 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,7 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, popped_to_iterator from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) @@ -42,11 +42,23 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -def deferred_size(deferred): - if deferred.called: - return len(deferred.result) - else: - return 1 +class CacheEntry(object): + __slots__ = [ + "deferred", "sequence", "callbacks", "invalidated" + ] + + def __init__(self, deferred, sequence, callbacks): + self.deferred = deferred + self.sequence = sequence + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() class Cache(object): @@ -58,13 +70,16 @@ class Cache(object): "sequence", "thread", "metrics", + "_pending_deferred_cache", ) def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict + self._pending_deferred_cache = cache_type() + self.cache = LruCache( max_size=max_entries, keylen=keylen, cache_type=cache_type, - size_callback=deferred_size if iterable else None, + size_callback=(lambda d: len(d.result)) if iterable else None, ) self.name = name @@ -84,7 +99,15 @@ class Cache(object): ) def get(self, key, default=_CacheSentinel, callback=None): - val = self.cache.get(key, _CacheSentinel, callback=callback) + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + if val.sequence == self.sequence: + val.callbacks.update(callbacks) + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: self.metrics.inc_hits() return val @@ -96,15 +119,39 @@ class Cache(object): else: return default - def update(self, sequence, key, value, callback=None): + def set(self, key, value, callback=None): + callbacks = [callback] if callback else [] self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value, callback=callback) + entry = CacheEntry( + deferred=value, + sequence=self.sequence, + callbacks=callbacks, + ) + + entry.callbacks.update(callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def shuffle(result): + if self.sequence == 
entry.sequence: + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, entry.deferred, entry.callbacks) + else: + entry.invalidate() + else: + entry.invalidate() + return result + + entry.deferred.addCallback(shuffle) def prefill(self, key, value, callback=None): - self.cache.set(key, value, callback=callback) + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): self.check_thread() @@ -116,6 +163,10 @@ class Cache(object): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 + entry = self._pending_deferred_cache.pop(key, None) + if entry: + entry.invalidate() + self.cache.pop(key, None) def invalidate_many(self, key): @@ -127,6 +178,12 @@ class Cache(object): self.sequence += 1 self.cache.del_multi(key) + val = self._pending_deferred_cache.pop(key, None) + if val is not None: + entry_dict, _ = val + for entry in popped_to_iterator(entry_dict): + entry.invalidate() + def invalidate_all(self): self.check_thread() self.sequence += 1 @@ -254,11 +311,6 @@ class CacheDescriptor(object): return preserve_context_over_deferred(observer) except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = cache.sequence - ret = defer.maybeDeferred( preserve_context_over_fn, self.function_to_call, @@ -272,7 +324,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret, callback=invalidate_callback) + cache.set(cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -370,7 +422,6 @@ class CacheListDescriptor(object): missing.append(arg) if missing: - sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -393,8 +444,8 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - cache.update( - sequence, tuple(key), observer, + cache.set( + tuple(key), observer, callback=invalidate_callback ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index b0ca1bb79d..cb6933c61c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,9 @@ import logging logger = logging.getLogger(__name__) -DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): + def __len__(self): + return len(self.value) class DictionaryCache(object): @@ -32,7 +34,7 @@ class DictionaryCache(object): """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) + self.cache = LruCache(max_size=max_entries, size_callback=len) self.name = name self.sequence = 0 diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index b9ead9cbd5..2987c38a2d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -56,6 +56,8 @@ class ExpiringCache(object): self.iterable = iterable + self._size_estimate = 0 + def start(self): if not self._expiry_ms: # Don't bother starting the loop if things never expire @@ -70,9 +72,14 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) + if 
self.iterable: + self._size_estimate += len(value) + # Evict if there are now too many items while self._max_len and len(self) > self._max_len: - self._cache.popitem(last=False) + _key, value = self._cache.popitem(last=False) + if self.iterable: + self._size_estimate -= len(value.value) def __getitem__(self, key): try: @@ -109,7 +116,9 @@ class ExpiringCache(object): keys_to_delete.add(key) for k in keys_to_delete: - self._cache.pop(k) + value = self._cache.pop(k) + if self.iterable: + self._size_estimate -= len(value.value) logger.debug( "[%s] _prune_cache before: %d, after len: %d", @@ -118,7 +127,7 @@ class ExpiringCache(object): def __len__(self): if self.iterable: - return sum(len(value.value) for value in self._cache.itervalues()) + return self._size_estimate else: return len(self._cache) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 00ddf38290..f1de034444 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -58,12 +58,6 @@ class LruCache(object): lock = threading.Lock() - def cache_len(): - if size_callback is not None: - return sum(size_callback(node.value) for node in cache.itervalues()) - else: - return len(cache) - def evict(): while cache_len() > max_size: todelete = list_root.prev_node @@ -78,6 +72,16 @@ class LruCache(object): return inner + cached_cache_len = [0] + if size_callback is not None: + def cache_len(): + return cached_cache_len[0] + else: + def cache_len(): + return len(cache) + + self.len = synchronized(cache_len) + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node @@ -86,6 +90,9 @@ class LruCache(object): next_node.prev_node = node cache[key] = node + if size_callback: + cached_cache_len[0] += size_callback(node.value) + def move_node_to_front(node): prev_node = node.prev_node next_node = node.next_node @@ -104,23 +111,25 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + for cb in node.callbacks: cb() node.callbacks.clear() @synchronized - def cache_get(key, default=None, callback=None): + def cache_get(key, default=None, callbacks=[]): node = cache.get(key, None) if node is not None: move_node_to_front(node) - if callback: - node.callbacks.add(callback) + node.callbacks.update(callbacks) return node.value else: return default @synchronized - def cache_set(key, value, callback=None): + def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: if value != node.value: @@ -128,17 +137,16 @@ class LruCache(object): cb() node.callbacks.clear() - if callback: - node.callbacks.add(callback) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) + + node.callbacks.update(callbacks) move_node_to_front(node) node.value = value else: - if callback: - callbacks = set([callback]) - else: - callbacks = set() - add_node(key, value, callbacks) + add_node(key, value, set(callbacks)) evict() diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c31585aea3..460e98a92d 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -65,12 +65,24 @@ class TreeCache(object): return popped def values(self): - return [e.value for e in self.root.values()] + return list(popped_to_iterator(self.root)) def __len__(self): return self.size +def popped_to_iterator(d): + if isinstance(d, dict): + for value_d in 
d.itervalues(): + for value in popped_to_iterator(value_d): + yield value + else: + if isinstance(d, _Entry): + yield d.value + else: + yield d + + class _Entry(object): __slots__ = ["value"] diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index ab6095564a..8361dd8cee 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=2) + @cached(max_entries=20) # HACK: This makes it 2 due to cache factor def func(self, key): callcount[0] += 1 return key @@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 2) self.assertEquals(callcount2[0], 2) + yield a.func2("foo") + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + yield a.func("foo3") self.assertEquals(callcount[0], 3) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index d888a64d0a..99aab65001 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.get("key", "value") @@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value2") @@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", [m]) self.assertFalse(m.called) cache.set("key", "value") @@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", [m]) self.assertFalse(m.called) cache.pop("key") @@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m4 = Mock() cache = LruCache(4, 2, cache_type=TreeCache) - cache.set(("a", "1"), "value", m1) - cache.set(("a", "2"), "value", m2) - cache.set(("b", "1"), "value", m3) - cache.set(("b", "2"), "value", m4) + cache.set(("a", "1"), "value", [m1]) + cache.set(("a", "2"), "value", [m2]) + cache.set(("b", "1"), "value", [m3]) + cache.set(("b", "2"), "value", [m4]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m2 = Mock() cache = LruCache(5) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", [m1]) + cache.set("key2", "value", [m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m3 = Mock(name="m3") cache = LruCache(2) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", [m1]) + cache.set("key2", "value", [m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key3", "value", m3) + cache.set("key3", "value", [m3]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) @@ -227,7 +227,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m2.call_count, 0) 
self.assertEquals(m3.call_count, 0) - cache.set("key1", "value", m1) + cache.set("key1", "value", [m1]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) -- cgit 1.4.1 From 9e8e236d9816ef639bdeb72cbb4de0fc29c6b120 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Jan 2017 11:48:02 +0000 Subject: Tidy up test --- tests/util/test_lrucache.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 99aab65001..dfb78cb8bd 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", [m]) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value") @@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", [m]) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.pop("key") @@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m4 = Mock() cache = LruCache(4, 2, cache_type=TreeCache) - cache.set(("a", "1"), "value", [m1]) - cache.set(("a", "2"), "value", [m2]) - cache.set(("b", "1"), "value", [m3]) - cache.set(("b", "2"), "value", [m4]) + cache.set(("a", "1"), "value", callbacks=[m1]) + cache.set(("a", "2"), "value", callbacks=[m2]) + cache.set(("b", "1"), "value", callbacks=[m3]) + cache.set(("b", "2"), "value", callbacks=[m4]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m2 = Mock() cache = LruCache(5) - cache.set("key1", "value", [m1]) - cache.set("key2", "value", [m2]) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m3 = Mock(name="m3") cache = LruCache(2) - cache.set("key1", "value", [m1]) - cache.set("key2", "value", [m2]) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key3", "value", [m3]) + cache.set("key3", "value", callbacks=[m3]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) @@ -227,7 +227,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key1", "value", [m1]) + cache.set("key1", "value", callbacks=[m1]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) -- cgit 1.4.1 From 5d6bad1b3c325897db81f84ebfc67ca687d851c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 13 Jan 2017 13:16:54 +0000 Subject: Optimise state resolution --- synapse/event_auth.py | 49 ++++++++-- synapse/events/__init__.py | 8 +- synapse/events/builder.py | 6 +- synapse/handlers/federation.py | 2 +- synapse/state.py | 209 +++++++++++++++++++++++++++++------------ tests/api/test_filtering.py | 5 +- tests/events/test_utils.py | 22 ++++- 7 files changed, 229 insertions(+), 72 deletions(-) (limited to 'tests') diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 983d8e9a85..3b7726a526 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -27,7 +27,7 @@ from synapse.types 
import UserID, get_domain_from_id logger = logging.getLogger(__name__) -def check(event, auth_events, do_sig_check=True): +def check(event, auth_events, do_sig_check=True, do_size_check=True): """ Checks if this event is correctly authed. Args: @@ -38,7 +38,8 @@ def check(event, auth_events, do_sig_check=True): Returns: True if the auth checks pass. """ - _check_size_limits(event) + if do_size_check: + _check_size_limits(event) if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) @@ -119,10 +120,11 @@ def check(event, auth_events, do_sig_check=True): ) return True - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Auth events: %s", + [a.event_id for a in auth_events.values()] + ) if event.type == EventTypes.Member: allowed = _is_membership_change_allowed( @@ -639,3 +641,38 @@ def get_public_keys(invite_event): public_keys.append(o) public_keys.extend(invite_event.content.get("public_keys", [])) return public_keys + + +def auth_types_for_event(event): + """Given an event, return a list of (EventType, StateKey) that may be + needed to auth the event. The returned list may be a superset of what + would actually be required depending on the full state of the room. + + Used to limit the number of events to fetch from the database to + actually auth the event. + """ + if event.type == EventTypes.Create: + return [] + + auth_types = [] + + auth_types.append((EventTypes.PowerLevels, "", )) + auth_types.append((EventTypes.Member, event.user_id, )) + auth_types.append((EventTypes.Create, "", )) + + if event.type == EventTypes.Member: + e_type = event.content["membership"] + if e_type in [Membership.JOIN, Membership.INVITE]: + auth_types.append((EventTypes.JoinRules, "", )) + + auth_types.append((EventTypes.Member, event.state_key, )) + + if e_type == Membership.INVITE: + if "third_party_invite" in event.content: + key = ( + EventTypes.ThirdPartyInvite, + event.content["third_party_invite"]["signed"]["token"] + ) + auth_types.append(key) + + return auth_types diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index da9f3ad436..e673e96cc0 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -79,7 +79,6 @@ class EventBase(object): auth_events = _event_dict_property("auth_events") depth = _event_dict_property("depth") content = _event_dict_property("content") - event_id = _event_dict_property("event_id") hashes = _event_dict_property("hashes") origin = _event_dict_property("origin") origin_server_ts = _event_dict_property("origin_server_ts") @@ -88,8 +87,6 @@ class EventBase(object): redacts = _event_dict_property("redacts") room_id = _event_dict_property("room_id") sender = _event_dict_property("sender") - state_key = _event_dict_property("state_key") - type = _event_dict_property("type") user_id = _event_dict_property("sender") @property @@ -162,6 +159,11 @@ class FrozenEvent(EventBase): else: frozen_dict = event_dict + self.event_id = event_dict["event_id"] + self.type = event_dict["type"] + if "state_key" in event_dict: + self.state_key = event_dict["state_key"] + super(FrozenEvent, self).__init__( frozen_dict, signatures=signatures, diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 7369d70980..365fd96bd2 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . 
import EventBase, FrozenEvent +from . import EventBase, FrozenEvent, _event_dict_property from synapse.types import EventID @@ -34,6 +34,10 @@ class EventBuilder(EventBase): internal_metadata_dict=internal_metadata_dict, ) + event_id = _event_dict_property("event_id") + state_key = _event_dict_property("state_key") + type = _event_dict_property("type") + def build(self): return FrozenEvent.from_event(self) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1021bcc405..ea89e0cf2d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d in different_events if d }) - new_state, prev_state = self.state_handler.resolve_events( + new_state = self.state_handler.resolve_events( [local_view.values(), remote_view.values()], event ) diff --git a/synapse/state.py b/synapse/state.py index 90b14e758c..294e0c2081 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -22,11 +22,10 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import AuthError -from synapse.api.auth import AuthEventTypes from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer -from collections import namedtuple +from collections import namedtuple, defaultdict from frozendict import frozendict import logging @@ -48,6 +47,8 @@ EVICTION_TIMEOUT_SECONDS = 60 * 60 _NEXT_STATE_ID = 1 +POWER_KEY = (EventTypes.PowerLevels, "") + def _gen_state_id(): global _NEXT_STATE_ID @@ -328,21 +329,13 @@ class StateHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) - state_map = yield self.store.get_events( - [e_id for st in state_groups_ids.values() for e_id in st.values()], - get_prev_content=False - ) - state_sets = [ - [state_map[e_id] for key, e_id in st.items() if e_id in state_map] - for st in state_groups_ids.values() - ] with Measure(self.clock, "state._resolve_events"): - new_state, _ = resolve_events( - state_sets, event_type, state_key + new_state = yield resolve_events( + state_groups_ids.values(), + state_map_factory=lambda ev_ids: self.store.get_events( + ev_ids, get_prev_content=False + ), ) - new_state = { - key: e.event_id for key, e in new_state.items() - } else: new_state = { key: e_ids.pop() for key, e_ids in state.items() @@ -390,13 +383,25 @@ class StateHandler(object): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) + state_set_ids = [{ + (ev.type, ev.state_key): ev.event_id + for ev in st + } for st in state_sets] + + state_map = { + ev.event_id: ev + for st in state_sets + for ev in st + } + with Measure(self.clock, "state._resolve_events"): - if event.is_state(): - return resolve_events( - state_sets, event.type, event.state_key - ) - else: - return resolve_events(state_sets) + new_state = resolve_events(state_set_ids, state_map) + + new_state = { + key: state_map[ev_id] for key, ev_id in new_state.items() + } + + return new_state def _ordered_events(events): @@ -406,43 +411,117 @@ def _ordered_events(events): return sorted(events, key=key_func) -def resolve_events(state_sets, event_type=None, state_key=""): +def resolve_events(state_sets, state_map_factory): """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. 
+ state_map_factory(dict|callable): If callable, then will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. Otherwise, should be + a dict from event_id to event of all events in state_sets. + Returns - (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple - (new_state, prev_states). new_state is a map from (type, state_key) - to event. prev_states is a list of event_ids. + dict[(str, str), synapse.events.FrozenEvent] is a map from + (type, state_key) to event. """ - state = {} - for st in state_sets: - for e in st: - state.setdefault( - (e.type, e.state_key), - {} - )[e.event_id] = e - - unconflicted_state = { - k: v.values()[0] for k, v in state.items() - if len(v.values()) == 1 - } + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + if callable(state_map_factory): + return _resolve_with_state_fac( + unconflicted_state, conflicted_state, state_map_factory + ) + + state_map = state_map_factory + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + return _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + + +def _seperate(state_sets): + """Takes the state_sets and figures out which keys are conflicted and + which aren't. i.e., which have multiple different event_ids associated + with them in different state sets. + """ + unconflicted_state = dict(state_sets[0]) + conflicted_state = {} + + full_states = defaultdict( + set, + {k: set((v,)) for k, v in state_sets[0].iteritems()} + ) + + for state_set in state_sets[1:]: + for key, value in state_set.iteritems(): + ls = full_states[key] + if not ls: + ls.add(value) + unconflicted_state[key] = value + elif value not in ls: + ls.add(value) + if len(ls) == 2: + conflicted_state[key] = ls + unconflicted_state.pop(key, None) + + return unconflicted_state, conflicted_state + + +@defer.inlineCallbacks +def _resolve_with_state_fac(unconflicted_state, conflicted_state, + state_map_factory): + needed_events = set( + event_id + for event_ids in conflicted_state.itervalues() + for event_id in event_ids + ) + + state_map = yield state_map_factory(needed_events) + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + new_needed_events = set(auth_events.itervalues()) + new_needed_events -= needed_events + + state_map_new = yield state_map_factory(new_needed_events) + state_map.update(state_map_new) + + defer.returnValue(_resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + )) + + +def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): + auth_events = {} + for event_ids in conflicted_state.itervalues(): + for event_id in event_ids: + keys = event_auth.auth_types_for_event(state_map[event_id]) + for key in keys: + if key not in auth_events: + event_id = unconflicted_state.get(key, None) + if event_id: + auth_events[key] = event_id + return auth_events + + +def _resolve_with_state(unconflicted_state, conflicted_state, auth_events, + state_map): conflicted_state = { - k: v.values() - for k, v in state.items() - if len(v.values()) > 1 + key: [state_map[ev_id] for ev_id in event_ids] + for key, event_ids in conflicted_state.items() } - if event_type: - prev_states_events = conflicted_state.get( - (event_type, state_key), [] - ) - prev_states = [s.event_id for s in prev_states_events] - else: - prev_states = [] - auth_events = { - k: e for k, e in 
unconflicted_state.items() - if k[0] in AuthEventTypes + key: state_map[ev_id] + for key, ev_id in auth_events.items() } try: @@ -454,9 +533,10 @@ def resolve_events(state_sets, event_type=None, state_key=""): raise new_state = unconflicted_state - new_state.update(resolved_state) + for key, event in resolved_state.iteritems(): + new_state[key] = event.event_id - return new_state, prev_states + return new_state def _resolve_state_events(conflicted_state, auth_events): @@ -470,11 +550,10 @@ def _resolve_state_events(conflicted_state, auth_events): 4. other events. """ resolved_state = {} - power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state: - events = conflicted_state[power_key] + if POWER_KEY in conflicted_state: + events = conflicted_state[POWER_KEY] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[power_key] = _resolve_auth_events( + resolved_state[POWER_KEY] = _resolve_auth_events( events, auth_events) auth_events.update(resolved_state) @@ -512,14 +591,26 @@ def _resolve_state_events(conflicted_state, auth_events): def _resolve_auth_events(events, auth_events): reverse = [i for i in reversed(_ordered_events(events))] - auth_events = dict(auth_events) + auth_keys = set( + key + for event in events + for key in event_auth.auth_types_for_event(event) + ) + + new_auth_events = {} + for key in auth_keys: + auth_event = auth_events.get(key, None) + if auth_event: + new_auth_events[key] = auth_event + + auth_events = new_auth_events prev_event = reverse[0] for event in reverse[1:]: auth_events[(prev_event.type, prev_event.state_key)] = prev_event try: # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) prev_event = event except AuthError: return prev_event @@ -531,7 +622,7 @@ def _resolve_normal_events(events, auth_events): for event in _ordered_events(events): try: # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False) + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) return event except AuthError: pass diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index dcb6c5bc31..50e8607c14 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -25,10 +25,13 @@ from synapse.api.filtering import Filter from synapse.events import FrozenEvent user_localpart = "test_user" -# MockEvent = namedtuple("MockEvent", "sender type room_id") def MockEvent(**kwargs): + if "event_id" not in kwargs: + kwargs["event_id"] = "fake_event_id" + if "type" not in kwargs: + kwargs["type"] = "fake_type" return FrozenEvent(kwargs) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 29f068d1f1..dfc870066e 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event def MockEvent(**kwargs): + if "event_id" not in kwargs: + kwargs["event_id"] = "fake_event_id" + if "type" not in kwargs: + kwargs["type"] = "fake_type" return FrozenEvent(kwargs) @@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase): def test_minimal(self): self.run_test( - {'type': 'A'}, { 'type': 'A', + 'event_id': '$test:domain', + }, + { + 'type': 'A', + 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {}, @@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase): self.run_test( { 
'type': 'B', + 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}, }, { 'type': 'B', + 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {'age_ts': 20}, @@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase): self.run_test( { 'type': 'B', + 'event_id': '$test:domain', 'unsigned': {'other_key': 'here'}, }, { 'type': 'B', + 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {}, @@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase): self.run_test( { 'type': 'C', + 'event_id': '$test:domain', 'content': {'things': 'here'}, }, { 'type': 'C', + 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {}, @@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase): self.run_test( { 'type': 'm.room.create', + 'event_id': '$test:domain', 'content': {'creator': '@2:domain', 'other_field': 'here'}, }, { 'type': 'm.room.create', + 'event_id': '$test:domain', 'content': {'creator': '@2:domain'}, 'signatures': {}, 'unsigned': {}, @@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase): self.assertEquals( self.serialize( MockEvent( + type="foo", + event_id="test", room_id="!foo:bar", content={ "foo": "bar", @@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase): [] ), { + "type": "foo", + "event_id": "test", "room_id": "!foo:bar", "content": { "foo": "bar", -- cgit 1.4.1 From 09eb08f910bd4a6077cca6ab4c3068eee55d59f3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 11:52:51 +0000 Subject: Derive current_state_events from state groups --- synapse/handlers/federation.py | 1 - synapse/state.py | 3 + synapse/storage/events.py | 188 ++++++++++++++++--------- tests/replication/slave/storage/test_events.py | 45 +++--- 4 files changed, 138 insertions(+), 99 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d3f5892376..996bfd0e23 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1319,7 +1319,6 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, new_event_context, - current_state=state, ) defer.returnValue((event_stream_id, max_stream_id)) diff --git a/synapse/state.py b/synapse/state.py index 20aaacf40f..383d32b163 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -429,6 +429,9 @@ def resolve_events(state_sets, state_map_factory): dict[(str, str), synapse.events.FrozenEvent] is a map from (type, state_key) to event. """ + if len(state_sets) == 1: + return state_sets[0] + unconflicted_state, conflicted_state = _seperate( state_sets, ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ca501932f3..0d6519f30d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from ._base import SQLBaseStore, _RollbackButIsFineException +from ._base import SQLBaseStore from twisted.internet import defer, reactor @@ -27,6 +27,7 @@ from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError +from synapse.state import resolve_events from canonicaljson import encode_canonical_json from collections import deque, namedtuple, OrderedDict @@ -71,22 +72,19 @@ class _EventPeristenceQueue(object): """ _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( - "events_and_contexts", "current_state", "backfilled", "deferred", + "events_and_contexts", "backfilled", "deferred", )) def __init__(self): self._event_persist_queues = {} self._currently_persisting_rooms = set() - def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state): + def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: end_item = queue[-1] - if end_item.current_state or current_state: - # We perist events with current_state set to True one at a time - pass if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) return end_item.deferred.observe() @@ -96,7 +94,6 @@ class _EventPeristenceQueue(object): queue.append(self._EventPersistQueueItem( events_and_contexts=events_and_contexts, backfilled=backfilled, - current_state=current_state, deferred=deferred, )) @@ -216,7 +213,6 @@ class EventsStore(SQLBaseStore): d = preserve_fn(self._event_persist_queue.add_to_queue)( room_id, evs_ctxs, backfilled=backfilled, - current_state=None, ) deferreds.append(d) @@ -229,11 +225,10 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function - def persist_event(self, event, context, current_state=None, backfilled=False): + def persist_event(self, event, context, backfilled=False): deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled, - current_state=current_state, ) self._maybe_start_persisting(event.room_id) @@ -246,21 +241,10 @@ class EventsStore(SQLBaseStore): def _maybe_start_persisting(self, room_id): @defer.inlineCallbacks def persisting_queue(item): - if item.current_state: - for event, context in item.events_and_contexts: - # There should only ever be one item in - # events_and_contexts when current_state is - # not None - yield self._persist_event( - event, context, - current_state=item.current_state, - backfilled=item.backfilled, - ) - else: - yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, - ) + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) self._event_persist_queue.handle_queue(room_id, persisting_queue) @@ -294,36 +278,89 @@ class EventsStore(SQLBaseStore): for chunk in chunks: # We can't easily parallelize these since different chunks # might contain the same event. :( + + current_state_for_room = {} + if not backfilled: + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. 
+ events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in events_by_room.items(): + # Work out new extremities by recursively adding and removing + # the new events. + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = set(latest_event_ids) + for event, ctx in ev_ctx_rm: + if event.internal_metadata.is_outlier(): + continue + + new_latest_event_ids.difference_update( + e_id for e_id, _ in event.prev_events + ) + new_latest_event_ids.add(event.event_id) + + if new_latest_event_ids == set(latest_event_ids): + # No change in extremities, so no change in state + continue + + # Now we need to work out the different state sets for + # each state extremities + state_sets = [] + missing_event_ids = [] + was_updated = False + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding, + # and then use the current state from that + for ev, ctx in ev_ctx_rm: + if event_id == ev.event_id: + if ctx.current_state_ids is None: + raise Exception("Unknown current state") + state_sets.append(ctx.current_state_ids) + if ctx.delta_ids or hasattr(ev, "state_key"): + was_updated = True + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + was_updated = True + missing_event_ids.append(event_id) + + if missing_event_ids: + # Now pull out the state for any missing events from DB + event_to_groups = yield self._get_state_group_for_events( + missing_event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) + + state_sets.extend(group_to_state.values()) + + if not new_latest_event_ids or was_updated: + current_state_for_room[room_id] = yield resolve_events( + state_sets, + state_map_factory=lambda ev_ids: self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ), + ) + yield self.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=chunk, backfilled=backfilled, delete_existing=delete_existing, + current_state_for_room=current_state_for_room, ) persist_event_counter.inc_by(len(chunk)) - @_retry_on_integrity_error - @defer.inlineCallbacks - @log_function - def _persist_event(self, event, context, current_state=None, backfilled=False, - delete_existing=False): - try: - with self._stream_id_gen.get_next() as stream_ordering: - event.internal_metadata.stream_ordering = stream_ordering - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - current_state=current_state, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc() - except _RollbackButIsFineException: - pass - @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, @@ -426,7 +463,7 @@ class EventsStore(SQLBaseStore): @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False): + delete_existing=False, current_state_for_room={}): """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -436,6 +473,40 @@ class EventsStore(SQLBaseStore): If delete_existing is True then existing events will be purged from the database before insertion. This is useful when retrying due to IntegrityError. 
""" + for room_id, current_state in current_state_for_room.iteritems(): + txn.call_after(self._get_current_state_for_key.invalidate_all) + txn.call_after(self.get_rooms_for_user.invalidate_all) + txn.call_after(self.get_users_in_room.invalidate, (room_id,)) + + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": stream_order} + ) + + self._simple_delete_txn( + txn, + table="current_state_events", + keyvalues={"room_id": room_id}, + ) + + self._simple_insert_many_txn( + txn, + table="current_state_events", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + } + for key, ev_id in current_state.iteritems() + ], + ) + # Ensure that we don't have the same event twice. # Pick the earliest non-outlier if there is one, else the earliest one. new_events_and_contexts = OrderedDict() @@ -798,29 +869,6 @@ class EventsStore(SQLBaseStore): # to update the current state table return - for event, _ in state_events_and_contexts: - if event.internal_metadata.is_outlier(): - # Outlier events shouldn't clobber the current state. - continue - - txn.call_after( - self._get_current_state_for_key.invalidate, - (event.room_id, event.type, event.state_key,) - ) - - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) - return def _add_to_cache(self, txn, events_and_contexts): diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 44e859b5d1..38fedfe690 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -60,7 +60,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): @defer.inlineCallbacks def test_room_members(self): - create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.replicate() yield self.check("get_rooms_for_user", (USER_ID,), []) yield self.check("get_users_in_room", (ROOM_ID,), []) @@ -95,15 +95,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): )]) yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) - # Join the room clobbering the state. - # This should remove any evidence of the other user being in the room. yield self.persist( type="m.room.member", key=USER_ID, membership="join", - reset_state=[create] ) yield self.replicate() - yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) - yield self.check("get_rooms_for_user", (USER_ID_2,), []) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2, USER_ID]) @defer.inlineCallbacks def test_get_latest_event_ids_in_room(self): @@ -125,7 +121,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): @defer.inlineCallbacks def test_get_current_state(self): # Create the room. 
- create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.replicate() yield self.check( "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), [] @@ -151,22 +147,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): [join2] ) - # Leave the room, then rejoin the room clobbering state. - yield self.persist(type="m.room.member", key=USER_ID, membership="leave") - join3 = yield self.persist( - type="m.room.member", key=USER_ID, membership="join", - reset_state=[create] - ) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), - [] - ) - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), - [join3] - ) - @defer.inlineCallbacks def test_redactions(self): yield self.persist(type="m.room.create", key="", creator=USER_ID) @@ -283,6 +263,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if depth is None: depth = self.event_id + if not prev_events: + latest_event_ids = yield self.master_store.get_latest_event_ids_in_room( + room_id + ) + prev_events = [(ev_id, {}) for ev_id in latest_event_ids] + event_dict = { "sender": sender, "type": type, @@ -309,12 +295,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_ids = { key: e.event_id for key, e in state.items() } + context = EventContext() + context.current_state_ids = state_ids + context.prev_state_ids = state_ids + elif not backfill: + state_handler = self.hs.get_state_handler() + context = yield state_handler.compute_event_context(event) else: - state_ids = None + context = EventContext() - context = EventContext() - context.current_state_ids = state_ids - context.prev_state_ids = state_ids context.push_actions = push_actions ordering = None @@ -324,7 +313,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) else: ordering, _ = yield self.master_store.persist_event( - event, context, current_state=reset_state + event, context, ) if ordering: -- cgit 1.4.1 From a55fa2047f813d639e2a0beed0c2d2738b0b639b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Jan 2017 15:40:04 +0000 Subject: Insert delta of current_state_events to be more efficient --- synapse/handlers/_base.py | 8 ++- synapse/replication/slave/storage/events.py | 10 ---- synapse/storage/events.py | 78 +++++++++++++++++--------- synapse/storage/state.py | 52 ----------------- tests/replication/slave/storage/test_events.py | 29 ---------- 5 files changed, 58 insertions(+), 119 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90f96209f8..e83adc8339 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -88,9 +88,13 @@ class BaseHandler(object): current_state = yield self.store.get_events( context.current_state_ids.values() ) - current_state = current_state.values() else: - current_state = yield self.store.get_current_state(event.room_id) + current_state = yield self.state_handler.get_current_state( + event.room_id + ) + + current_state = current_state.values() + logger.info("maybe_kick_guest_users %r", current_state) yield self.kick_guest_users(current_state) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 64f18bbb3e..b3f3bf7488 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -76,9 +76,6 @@ class SlavedEventStore(BaseSlavedStore): 
get_latest_event_ids_in_room = EventFederationStore.__dict__[ "get_latest_event_ids_in_room" ] - _get_current_state_for_key = StateStore.__dict__[ - "_get_current_state_for_key" - ] get_invited_rooms_for_user = RoomMemberStore.__dict__[ "get_invited_rooms_for_user" ] @@ -115,8 +112,6 @@ class SlavedEventStore(BaseSlavedStore): ) get_event = DataStore.get_event.__func__ get_events = DataStore.get_events.__func__ - get_current_state = DataStore.get_current_state.__func__ - get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( DataStore.get_rooms_for_user_where_membership_is.__func__ ) @@ -248,7 +243,6 @@ class SlavedEventStore(BaseSlavedStore): def invalidate_caches_for_event(self, event, backfilled, reset_state): if reset_state: - self._get_current_state_for_key.invalidate_all() self.get_rooms_for_user.invalidate_all() self.get_users_in_room.invalidate((event.room_id,)) @@ -289,7 +283,3 @@ class SlavedEventStore(BaseSlavedStore): if (not event.internal_metadata.is_invite_from_remote() and event.internal_metadata.is_outlier()): return - - self._get_current_state_for_key.invalidate(( - event.room_id, event.type, event.state_key - )) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6160949f32..9f57760ab0 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -476,37 +476,63 @@ class EventsStore(SQLBaseStore): """ max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering for room_id, current_state in current_state_for_room.iteritems(): - txn.call_after(self._get_current_state_for_key.invalidate_all) - txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, (room_id,)) - - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": max_stream_order} - ) - - self._simple_delete_txn( + existing_state_rows = self._simple_select_list_txn( txn, table="current_state_events", keyvalues={"room_id": room_id}, + retcols=["event_id", "type", "state_key"], ) - self._simple_insert_many_txn( - txn, - table="current_state_events", - values=[ - { - "event_id": ev_id, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - } - for key, ev_id in current_state.iteritems() - ], - ) + existing_events = set(row["event_id"] for row in existing_state_rows) + new_events = set(ev_id for ev_id in current_state.itervalues()) + changed_events = existing_events ^ new_events + if changed_events: + txn.executemany( + "DELETE FROM current_state_events WHERE event_id = ?", + [(ev_id,) for ev_id in changed_events], + ) + + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": max_stream_order} + ) + + events_to_insert = (new_events - existing_events) + to_insert = [ + (key, ev_id) for key, ev_id in current_state.iteritems() + if ev_id in events_to_insert + ] + self._simple_insert_many_txn( + txn, + table="current_state_events", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + } + for key, ev_id in to_insert + ], + ) + + members_changed = set( + row["state_key"] for row in existing_state_rows + if row["event_id"] in changed_events + and row["type"] == EventTypes.Member + ) + 
members_changed.update( + key[1] for key, event_id in to_insert + if key[0] == EventTypes.Member + ) + + for member in members_changed: + txn.call_after(self.get_rooms_for_user.invalidate, (member,)) + + txn.call_after(self.get_users_in_room.invalidate, (room_id,)) for room_id, new_extrem in new_forward_extremeties.items(): self._simple_delete_txn( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7d34dd03bf..d1d653327c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -232,58 +232,6 @@ class StateStore(SQLBaseStore): return count - @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type and state_key is not None: - result = yield self.get_current_state_for_key( - room_id, event_type, state_key - ) - defer.returnValue(result) - - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? " - ) - - if event_type and state_key is not None: - sql += " AND type = ? AND state_key = ? " - args = (room_id, event_type, state_key) - elif event_type: - sql += " AND type = ?" - args = (room_id, event_type) - else: - args = (room_id, ) - - txn.execute(sql, args) - results = txn.fetchall() - - return [r[0] for r in results] - - event_ids = yield self.runInteraction("get_current_state", f) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @defer.inlineCallbacks - def get_current_state_for_key(self, room_id, event_type, state_key): - event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @cached(num_args=3) - def _get_current_state_for_key(self, room_id, event_type, state_key): - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?" - ) - - args = (room_id, event_type, state_key) - txn.execute(sql, args) - results = txn.fetchall() - return [r[0] for r in results] - return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 38fedfe690..6acb8ab758 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -118,35 +118,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] ) - @defer.inlineCallbacks - def test_get_current_state(self): - # Create the room. - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), [] - ) - - # Join the room. - join1 = yield self.persist( - type="m.room.member", key=USER_ID, membership="join", - ) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), - [join1] - ) - - # Add some other user to the room. 
- join2 = yield self.persist( - type="m.room.member", key=USER_ID_2, membership="join", - ) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), - [join2] - ) - @defer.inlineCallbacks def test_redactions(self): yield self.persist(type="m.room.create", key="", creator=USER_ID) -- cgit 1.4.1 From 2367c5568c01bc65aacc955b76ba707918b37f1e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Jan 2017 14:27:27 +0000 Subject: Add basic implementation of local device list changes --- synapse/federation/transaction_queue.py | 24 ++- synapse/handlers/device.py | 65 ++++++-- synapse/handlers/e2e_keys.py | 1 + synapse/handlers/sync.py | 13 ++ synapse/rest/client/v2_alpha/sync.py | 6 +- synapse/storage/__init__.py | 11 ++ synapse/storage/_base.py | 6 + synapse/storage/devices.py | 169 +++++++++++++++++++-- synapse/storage/end_to_end_keys.py | 23 ++- .../schema/delta/40/device_list_streams.sql | 56 +++++++ synapse/streams/events.py | 4 + synapse/types.py | 2 + tests/handlers/test_typing.py | 3 + tests/rest/client/v1/test_rooms.py | 4 +- 14 files changed, 348 insertions(+), 39 deletions(-) create mode 100644 synapse/storage/schema/delta/40/device_list_streams.sql (limited to 'tests') diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 6b3a7abb9e..65c6673a87 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -100,6 +100,7 @@ class TransactionQueue(object): self.pending_failures_by_dest = {} self.last_device_stream_id_by_dest = {} + self.last_device_list_stream_id_by_dest = {} # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) @@ -356,7 +357,7 @@ class TransactionQueue(object): success = yield self._send_new_transaction( destination, pending_pdus, pending_edus, pending_failures, device_stream_id, - should_delete_from_device_stream=bool(device_message_edus), + includes_device_messages=bool(device_message_edus), limiter=limiter, ) if not success: @@ -373,6 +374,8 @@ class TransactionQueue(object): @defer.inlineCallbacks def _get_new_device_messages(self, destination): + # TODO: Send appropriate device list messages + last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) to_device_stream_id = self.store.get_to_device_stream_token() contents, stream_id = yield self.store.get_new_device_msgs_for_remote( @@ -387,13 +390,27 @@ class TransactionQueue(object): ) for content in contents ] + + last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0) + now_stream_id, results = yield self.store.get_devices_by_remote( + destination, last_device_list + ) + edus.extend( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.device_list_update", + content=content, + ) + for content in results + ) defer.returnValue((edus, stream_id)) @measure_func("_send_new_transaction") @defer.inlineCallbacks def _send_new_transaction(self, destination, pending_pdus, pending_edus, pending_failures, device_stream_id, - should_delete_from_device_stream, limiter): + includes_device_messages, limiter): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) @@ -506,7 +523,8 @@ class TransactionQueue(object): success = False else: # Remove the acknowledged device messages from the database - if should_delete_from_device_stream: + # Only bother if we actually sent some device messages + if includes_device_messages: yield self.store.delete_device_msgs_for_remote( destination, 
device_stream_id ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index aa68755936..d92780b642 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,6 +15,7 @@ from synapse.api import errors from synapse.util import stringutils +from synapse.types import get_domain_from_id from twisted.internet import defer from ._base import BaseHandler @@ -27,6 +28,8 @@ class DeviceHandler(BaseHandler): def __init__(self, hs): super(DeviceHandler, self).__init__(hs) + self.state = hs.get_state_handler() + @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, initial_device_display_name=None): @@ -45,29 +48,29 @@ class DeviceHandler(BaseHandler): str: device id (generated if none was supplied) """ if device_id is not None: - yield self.store.store_device( + new_device = yield self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, - ignore_if_known=True, ) + if new_device: + yield self.notify_device_update(user_id, device_id) defer.returnValue(device_id) # if the device id is not specified, we'll autogen one, but loop a few # times in case of a clash. attempts = 0 while attempts < 5: - try: - device_id = stringutils.random_string(10).upper() - yield self.store.store_device( - user_id=user_id, - device_id=device_id, - initial_device_display_name=initial_device_display_name, - ignore_if_known=False, - ) + device_id = stringutils.random_string(10).upper() + new_device = yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ) + if new_device: + yield self.notify_device_update(user_id, device_id) defer.returnValue(device_id) - except errors.StoreError: - attempts += 1 + attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @@ -147,6 +150,8 @@ class DeviceHandler(BaseHandler): user_id=user_id, device_id=device_id ) + yield self.notify_device_update(user_id, device_id) + @defer.inlineCallbacks def update_device(self, user_id, device_id, content): """ Update the given device @@ -166,12 +171,48 @@ class DeviceHandler(BaseHandler): device_id, new_display_name=content.get("display_name") ) + yield self.notify_device_update(user_id, device_id) except errors.StoreError, e: if e.code == 404: raise errors.NotFoundError() else: raise + @defer.inlineCallbacks + def notify_device_update(self, user_id, device_id): + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = [r.room_id for r in rooms] + + hosts = set() + for room_id in room_ids: + users = yield self.state.get_current_user_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in users) + hosts.discard(self.server_name) + + position = yield self.store.add_device_change_to_streams( + user_id, device_id, list(hosts) + ) + + yield self.notifier.on_new_event( + "device_list_key", position, rooms=room_ids, + ) + + for host in hosts: + self.federation.send_device_messages(host) + + @defer.inlineCallbacks + def get_device_list_changes(self, user_id, room_ids, from_key): + room_ids = frozenset(room_ids) + + user_ids_changed = set() + changed = yield self.store.get_user_whose_devices_changed(from_key) + for other_user_id in changed: + other_rooms = yield self.store.get_rooms_for_user(other_user_id) + if room_ids.intersection(e.room_id for e in other_rooms): + user_ids_changed.add(other_user_id) + + defer.returnValue(user_ids_changed) + def _update_device_from_client_ips(device, client_ips): ip = 
client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index b63a660c06..38c2a2d39e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -259,6 +259,7 @@ class E2eKeysHandler(object): user_id, device_id, time_now, encode_canonical_json(device_keys) ) + yield self.device_handler.notify_device_update(user_id, device_id) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c880f61685..06bf626367 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. "to_device", # List of direct messages for the device. + "device_lists", # List of user_ids whose devices have chanegd ])): __slots__ = [] @@ -143,6 +144,7 @@ class SyncHandler(object): self.clock = hs.get_clock() self.response_cache = ResponseCache(hs) self.state = hs.get_state_handler() + self.device_handler = hs.get_device_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): @@ -544,6 +546,16 @@ class SyncHandler(object): yield self._generate_sync_entry_for_to_device(sync_result_builder) + if since_token and since_token.device_list_key: + user_id = sync_config.user.to_string() + rooms = yield self.store.get_rooms_for_user(user_id) + joined_room_ids = set(r.room_id for r in rooms) + device_lists = yield self.device_handler.get_device_list_changes( + user_id, joined_room_ids, since_token.device_list_key + ) + else: + device_lists = [] + defer.returnValue(SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, @@ -551,6 +563,7 @@ class SyncHandler(object): invited=sync_result_builder.invited, archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, + device_lists=device_lists, next_batch=sync_result_builder.now_token, )) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 7199ec883a..b3d8001638 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet): ) archived = self.encode_archived( - sync_result.archived, time_now, requester.access_token_id, filter.event_fields + sync_result.archived, time_now, requester.access_token_id, + filter.event_fields, ) response_content = { "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, + "device_lists": { + "changed": list(sync_result.device_lists), + }, "presence": self.encode_presence( sync_result.presence, time_now ), diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e8495f1eb9..b9968debe5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore, self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) + self._device_list_id_gen = StreamIdGenerator( + db_conn, "device_lists_stream", "stream_id", + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") @@ -210,6 +213,14 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=device_outbox_prefill, ) + device_list_max = 
self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 963ef999d5..05374682fd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -387,6 +387,10 @@ class SQLBaseStore(object): Args: table : string giving the table name values : dict of new column names and values for them + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True """ try: yield self.runInteraction( @@ -398,6 +402,8 @@ class SQLBaseStore(object): # a cursor after we receive an error from the db. if not or_ignore: raise + defer.returnValue(False) + defer.returnValue(True) @staticmethod def _simple_insert_txn(txn, table, values): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 17920d4480..b594f501f9 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import ujson as json from twisted.internet import defer @@ -33,17 +34,13 @@ class DeviceStore(SQLBaseStore): user_id (str): id of user associated with the device device_id (str): id of device initial_device_display_name (str): initial displayname of the - device - ignore_if_known (bool): ignore integrity errors which mean the - device is already known + device. Ignored if device exists. Returns: - defer.Deferred - Raises: - StoreError: if ignore_if_known is False and the device was already - known + defer.Deferred: boolean whether the device was inserted or an + existing device existed with that ID. """ try: - yield self._simple_insert( + inserted = yield self._simple_insert( "devices", values={ "user_id": user_id, @@ -51,8 +48,9 @@ class DeviceStore(SQLBaseStore): "display_name": initial_device_display_name }, desc="store_device", - or_ignore=ignore_if_known, + or_ignore=True, ) + defer.returnValue(inserted) except Exception as e: logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" " display_name=%s(%r) failed: %s", @@ -139,3 +137,156 @@ class DeviceStore(SQLBaseStore): ) defer.returnValue({d["device_id"]: d for d in devices}) + + def get_devices_by_remote(self, destination, from_stream_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + has_changed = self._device_list_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + defer.returnValue((now_stream_id, [])) + + return self.runInteraction( + "get_devices_by_remote", self._get_devices_by_remote_txn, + destination, from_stream_id, now_stream_id, + ) + + def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, + now_stream_id): + sql = """ + SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id > ? AND stream_id <= ? AND sent = ? 
+ GROUP BY user_id, device_id + """ + txn.execute( + sql, (destination, from_stream_id, now_stream_id, False) + ) + rows = txn.fetchall() + + if not rows: + return now_stream_id, [] + + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in rows} + devices = self._get_e2e_device_keys_txn( + txn, query_map.keys(), include_all_devices=True + ) + + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND sent = ? + """ + + results = [] + for user_id, user_devices in devices.iteritems(): + txn.execute(prev_sent_id_sql, (destination, user_id, True)) + rows = txn.fetchall() + prev_id = rows[0][0] + for device_id, result in user_devices.iteritems(): + stream_id = query_map[(user_id, device_id)] + result = { + "user_id": user_id, + "device_id": device_id, + "prev_id": prev_id, + "stream_id": stream_id, + } + + prev_id = stream_id + + key_json = result.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = result.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.setdefault(user_id, {})[device_id] = result + + return now_stream_id, results + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + return self.runInteraction( + "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, + destination, stream_id, + ) + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT user_id FROM device_lists_stream WHERE stream_id > ? + """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row["user_id"] for row in rows)) + + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id < ( + SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + ) + """ + txn.execute(sql, (destination, destination, stream_id,)) + + sql = """ + UPDATE device_lists_outbound_pokes SET sent = ? + WHERE destination = ? AND stream_id <= ? 
+ """ + txn.execute(sql, (destination, True,)) + + @defer.inlineCallbacks + def add_device_change_to_streams(self, user_id, device_id, hosts): + # device_lists_stream + # device_lists_outbound_pokes + with self._device_list_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_device_change_to_streams", self._add_device_change_txn, + user_id, device_id, hosts, stream_id, + ) + defer.returnValue(stream_id) + + def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id): + txn.call_after( + self._device_list_stream_cache.entity_has_changed, + user_id, stream_id, + ) + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, stream_id, + ) + + self._simple_insert_txn( + txn, + table="device_lists_stream", + values={ + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_outbound_pokes", + values=[ + { + "destination": destination, + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + "sent": False, + } + for destination in hosts + ] + ) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 385d607056..f82943a7a8 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,9 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections - -import twisted.internet.defer +from twisted.internet import defer from ._base import SQLBaseStore @@ -33,7 +31,7 @@ class EndToEndKeyStore(SQLBaseStore): } ) - def get_e2e_device_keys(self, query_list): + def get_e2e_device_keys(self, query_list, include_all_devices=False): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. 
@@ -45,10 +43,11 @@ class EndToEndKeyStore(SQLBaseStore): return {} return self.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + "get_e2e_device_keys", self._get_e2e_device_keys_txn, + query_list, include_all_devices, ) - def _get_e2e_device_keys_txn(self, txn, query_list): + def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): query_clauses = [] query_params = [] @@ -63,23 +62,23 @@ class EndToEndKeyStore(SQLBaseStore): query_clauses.append(query_clause) sql = ( - "SELECT k.user_id, k.device_id, " + "SELECT user_id, device_id, " " d.display_name AS device_display_name, " " k.key_json" " FROM e2e_device_keys_json k" - " LEFT JOIN devices d ON d.user_id = k.user_id" - " AND d.device_id = k.device_id" + " %s JOIN devices d USING (user_id, device_id)" " WHERE %s" ) % ( + "FULL OUTER" if include_all_devices else "LEFT", " OR ".join("(" + q + ")" for q in query_clauses) ) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) - result = collections.defaultdict(dict) + result = {} for row in rows: - result[row["user_id"]][row["device_id"]] = row + result.setdefault(row["user_id"], {})[row["device_id"]] = row return result @@ -152,7 +151,7 @@ class EndToEndKeyStore(SQLBaseStore): "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @twisted.internet.defer.inlineCallbacks + @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): yield self._simple_delete( table="e2e_device_keys_json", diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql new file mode 100644 index 0000000000..61cac63bbb --- /dev/null +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -0,0 +1,56 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE device_list_streams_remote ( + list_id TEXT NOT NULL, + origin TEXT NOT NULL, + user_id TEXT NOT NULL, + is_full BOOLEAN NOT NULL, + ts BIGINT NOT NULL +); + +CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote( + origin, list_id, user_id +); + + +CREATE TABLE device_lists_remote_cache ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); + + +CREATE TABLE device_lists_stream ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); + + +CREATE TABLE device_lists_outbound_pokes ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + sent BOOLEAN NOT NULL +); + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 4d44c3d4ca..91a59b0bae 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -44,6 +44,7 @@ class EventSources(object): def get_current_token(self): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() + device_list_key = self.store.get_device_stream_token() token = StreamToken( room_key=( @@ -63,6 +64,7 @@ class EventSources(object): ), push_rules_key=push_rules_key, to_device_key=to_device_key, + device_list_key=device_list_key, ) defer.returnValue(token) @@ -70,6 +72,7 @@ class EventSources(object): def get_current_token_for_room(self, room_id): push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() + device_list_key = self.store.get_device_stream_token() token = StreamToken( room_key=( @@ -89,5 +92,6 @@ class EventSources(object): ), push_rules_key=push_rules_key, to_device_key=to_device_key, + device_list_key=device_list_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 3a3ab21d17..9666f9d73f 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -158,6 +158,7 @@ class StreamToken( "account_data_key", "push_rules_key", "to_device_key", + "device_list_key", )) ): _SEPARATOR = "_" @@ -195,6 +196,7 @@ class StreamToken( or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.to_device_key) < int(self.to_device_key)) + or (int(other.device_list_key) < int(self.device_list_key)) ) def copy_and_advance(self, key, new_value): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index c718d1f98f..f88d2be7c5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -75,6 +75,7 @@ class TypingNotificationsTestCase(unittest.TestCase): "get_received_txn_response", "set_received_txn_response", "get_destination_retry_timings", + "get_devices_by_remote", ]), state_handler=self.state_handler, handlers=None, @@ -99,6 +100,8 @@ class TypingNotificationsTestCase(unittest.TestCase): defer.succeed(retry_timings_res) ) + self.datastore.get_devices_by_remote.return_value = (0, []) + def get_received_txn_response(*args): return defer.succeed(None) self.datastore.get_received_txn_response = get_received_txn_response diff --git 
a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 6bce352c5f..d746ea8568 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase): @defer.inlineCallbacks def test_topo_token_is_accepted(self): - token = "t1-0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0" (code, response) = yield self.mock_resource.trigger_get( "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)) @@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase): @defer.inlineCallbacks def test_stream_token_is_accepted_for_fwd_pagianation(self): - token = "s0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0" (code, response) = yield self.mock_resource.trigger_get( "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)) -- cgit 1.4.1 From c974116f197d211ba9b42159fe61cfd5957411b5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jan 2017 16:06:54 +0000 Subject: Implement device key caching over federation --- synapse/federation/federation_client.py | 10 + synapse/federation/federation_server.py | 3 + synapse/federation/transport/client.py | 26 +++ synapse/federation/transport/server.py | 8 + synapse/handlers/device.py | 85 +++++++-- synapse/handlers/e2e_keys.py | 40 +++- synapse/storage/devices.py | 201 +++++++++++++++++++-- synapse/storage/end_to_end_keys.py | 4 +- .../schema/delta/40/device_list_streams.sql | 20 +- tests/handlers/test_device.py | 18 +- tests/handlers/test_directory.py | 1 + tests/handlers/test_profile.py | 1 + tests/storage/test_appservice.py | 21 ++- 13 files changed, 381 insertions(+), 57 deletions(-) (limited to 'tests') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c9175bb33d..b5bcfd705a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -126,6 +126,16 @@ class FederationClient(FederationBase): destination, content, timeout ) + @log_function + def query_user_devices(self, destination, user_id, timeout=30000): + """Query the device keys for a list of user ids hosted on a remote + server. + """ + sent_queries_counter.inc("user_devices") + return self.transport_layer.query_user_devices( + destination, user_id, timeout + ) + @log_function def claim_client_keys(self, destination, content, timeout): """Claims one-time keys for a device hosted on a remote server. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 862ccbef5d..e922b7ff4a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -416,6 +416,9 @@ class FederationServer(FederationBase): def on_query_client_keys(self, origin, content): return self.on_query_request("client_keys", content) + def on_query_user_devices(self, origin, user_id): + return self.on_query_request("user_devices", user_id) + @defer.inlineCallbacks @log_function def on_claim_client_keys(self, origin, content): diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 915af34409..f49e8a2cc4 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -346,6 +346,32 @@ class TransportLayerClient(object): ) defer.returnValue(content) + @defer.inlineCallbacks + @log_function + def query_user_devices(self, destination, user_id, timeout): + """Query the devices for a user id hosted on a remote server. 
+
+        Response:
+            {
+              "stream_id": "...",
+              "devices": [ { ... } ]
+            }
+
+        Args:
+            destination(str): The server to query.
+            user_id(str): The user id to query devices for.
+        Returns:
+            A dict containing the device keys.
+        """
+        path = PREFIX + "/user/devices/" + user_id
+
+        content = yield self.client.get_json(
+            destination=destination,
+            path=path,
+            timeout=timeout,
+        )
+        defer.returnValue(content)
+
     @defer.inlineCallbacks
     @log_function
     def claim_client_keys(self, destination, query_content, timeout):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 159dbd1747..c840da834c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
         return self.handler.on_query_client_keys(origin, content)
 
 
+class FederationUserDevicesQueryServlet(BaseFederationServlet):
+    PATH = "/user/devices/(?P<user_id>[^/]*)"
+
+    def on_GET(self, origin, content, query, user_id):
+        return self.handler.on_query_user_devices(origin, user_id)
+
+
 class FederationClientKeysClaimServlet(BaseFederationServlet):
     PATH = "/user/keys/claim"
 
@@ -613,6 +620,7 @@ SERVLET_CLASSES = (
     FederationGetMissingEventsServlet,
     FederationEventAuthServlet,
     FederationClientKeysQueryServlet,
+    FederationUserDevicesQueryServlet,
     FederationClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ba4c48d590..2d66b3721a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,7 @@
 
 from synapse.api import errors
 from synapse.util import stringutils
+from synapse.util.async import Linearizer
 from synapse.types import get_domain_from_id
 from twisted.internet import defer
 from ._base import BaseHandler
@@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler):
     def __init__(self, hs):
         super(DeviceHandler, self).__init__(hs)
 
+        self.hs = hs
         self.state = hs.get_state_handler()
-        self.federation = hs.get_federation_sender()
+        self.federation_sender = hs.get_federation_sender()
+        self.federation = hs.get_replication_layer()
+        self._remote_edue_linearizer = Linearizer(name="remote_device_list")
+
+        self.federation.register_edu_handler(
+            "m.device_list_update", self._incoming_device_list_update,
+        )
+        self.federation.register_query_handler(
+            "user_devices", self.on_federation_query_user_devices,
+        )
 
     @defer.inlineCallbacks
     def check_device_registered(self, user_id, device_id,
@@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler):
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                yield self.notify_device_update(user_id, device_id)
+                yield self.notify_device_update(user_id, [device_id])
             defer.returnValue(device_id)
 
         # if the device id is not specified, we'll autogen one, but loop a few
@@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler):
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                yield self.notify_device_update(user_id, device_id)
+                yield self.notify_device_update(user_id, [device_id])
             defer.returnValue(device_id)
         attempts += 1
 
@@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler):
             user_id=user_id, device_id=device_id
         )
 
-        yield self.notify_device_update(user_id, device_id)
+        yield self.notify_device_update(user_id, [device_id])
 
     @defer.inlineCallbacks
     def update_device(self, user_id, device_id, content):
@@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler):
                 device_id,
                 new_display_name=content.get("display_name")
             )
 
-        yield
self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) except errors.StoreError, e: if e.code == 404: raise errors.NotFoundError() @@ -180,26 +191,28 @@ class DeviceHandler(BaseHandler): raise @defer.inlineCallbacks - def notify_device_update(self, user_id, device_id): + def notify_device_update(self, user_id, device_ids): rooms = yield self.store.get_rooms_for_user(user_id) room_ids = [r.room_id for r in rooms] hosts = set() - for room_id in room_ids: - users = yield self.state.get_current_user_in_room(room_id) - hosts.update(get_domain_from_id(u) for u in users) - hosts.discard(self.server_name) + if self.hs.is_mine_id(user_id): + for room_id in room_ids: + users = yield self.state.get_current_user_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in users) + hosts.discard(self.server_name) position = yield self.store.add_device_change_to_streams( - user_id, device_id, list(hosts) + user_id, device_ids, list(hosts) ) yield self.notifier.on_new_event( "device_list_key", position, rooms=room_ids, ) + logger.info("Sending device list update notif to: %r", hosts) for host in hosts: - self.federation.send_device_messages(host) + self.federation_sender.send_device_messages(host) @defer.inlineCallbacks def get_device_list_changes(self, user_id, room_ids, from_key): @@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler): defer.returnValue(user_ids_changed) + @defer.inlineCallbacks + def _incoming_device_list_update(self, origin, edu_content): + user_id = edu_content["user_id"] + device_id = edu_content["device_id"] + stream_id = edu_content["stream_id"] + prev_ids = edu_content.get("prev_id", []) + + if get_domain_from_id(user_id) != origin: + # TODO: Raise? + return + + logger.info("Got edu: %r", edu_content) + + with (yield self._remote_edue_linearizer.queue(user_id)): + resync = True + if len(prev_ids) == 1: + extremity = yield self.store.get_device_list_remote_extremity(user_id) + logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids) + if str(extremity) == str(prev_ids[0]): + resync = False + + if resync: + result = yield self.federation.query_user_devices(origin, user_id) + stream_id = result["stream_id"] + devices = result["devices"] + yield self.store.update_remote_device_list_cache( + user_id, devices, stream_id, + ) + device_ids = [device["device_id"] for device in devices] + yield self.notify_device_update(user_id, device_ids) + else: + content = dict(edu_content) + for key in ("user_id", "device_id", "stream_id", "prev_ids"): + content.pop(key, None) + yield self.store.update_remote_device_list_cache_entry( + user_id, device_id, content, stream_id, + ) + yield self.notify_device_update(user_id, [device_id]) + + @defer.inlineCallbacks + def on_federation_query_user_devices(self, user_id): + stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) + defer.returnValue({ + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + }) + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 38c2a2d39e..832998a6d3 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -73,8 +73,7 @@ class E2eKeysHandler(object): if self.is_mine_id(user_id): local_query[user_id] = device_ids else: - domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_ids + remote_queries[user_id] = device_ids # do the 
queries failures = {} @@ -85,9 +84,40 @@ class E2eKeysHandler(object): if user_id in local_query: results[user_id] = keys + remote_queries_not_in_cache = {} + if remote_queries: + query_list = [] + for user_id, device_ids in remote_queries.iteritems(): + if device_ids: + query_list.extend((user_id, device_id) for device_id in device_ids) + else: + query_list.append((user_id, None)) + + user_ids_not_in_cache, remote_results = ( + yield self.store.get_user_devices_from_cache( + query_list + ) + ) + for user_id, devices in remote_results.iteritems(): + user_devices = results.setdefault(user_id, {}) + for device_id, device in devices.iteritems(): + keys = device.get("keys", None) + device_display_name = device.get("device_display_name", None) + if keys: + result = dict(keys) + unsigned = result.setdefault("unsigned", {}) + if device_display_name: + unsigned["device_display_name"] = device_display_name + user_devices[device_id] = result + + for user_id in user_ids_not_in_cache: + domain = get_domain_from_id(user_id) + r = remote_queries_not_in_cache.setdefault(domain, {}) + r[user_id] = remote_queries[user_id] + @defer.inlineCallbacks def do_remote_query(destination): - destination_query = remote_queries[destination] + destination_query = remote_queries_not_in_cache[destination] try: limiter = yield get_retry_limiter( destination, self.clock, self.store @@ -119,7 +149,7 @@ class E2eKeysHandler(object): yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(do_remote_query)(destination) - for destination in remote_queries + for destination in remote_queries_not_in_cache ])) defer.returnValue({ @@ -259,7 +289,7 @@ class E2eKeysHandler(object): user_id, device_id, time_now, encode_canonical_json(device_keys) ) - yield self.device_handler.notify_device_update(user_id, device_id) + yield self.device_handler.notify_device_update(user_id, [device_id]) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 9628e2ff75..8ee3119db2 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) + def get_device_list_remote_extremity(self, user_id): + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_remote_extremity", + allow_none=True, + ) + + def update_remote_device_list_cache_entry(self, user_id, device_id, content, + stream_id): + return self.runInteraction( + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, device_id, content, stream_id, + ) + + def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, + content, stream_id): + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + } + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def update_remote_device_list_cache(self, user_id, devices, stream_id): + return self.runInteraction( + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, devices, stream_id, + ) + + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, + stream_id): + 
self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + "device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ] + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + def get_devices_by_remote(self, destination, from_stream_id): now_stream_id = self._device_list_id_gen.get_current_token() @@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore): txn.execute(prev_sent_id_sql, (destination, user_id, True)) rows = txn.fetchall() prev_id = rows[0][0] - for device_id, result in user_devices.iteritems(): + for device_id, device in user_devices.iteritems(): stream_id = query_map[(user_id, device_id)] result = { "user_id": user_id, @@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore): prev_id = stream_id - key_json = result.get("key_json", None) + key_json = device.get("key_json", None) if key_json: result["keys"] = json.loads(key_json) - device_display_name = result.get("device_display_name", None) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, results) + def get_user_devices_from_cache(self, query_list): + return self.runInteraction( + "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, + query_list, + ) + + def _get_user_devices_from_cache_txn(self, txn, query_list): + user_ids = {user_id for user_id, _ in query_list} + + user_ids_in_cache = set() + for user_id in user_ids: + stream_ids = self._simple_select_onecol_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + retcol="stream_id", + ) + if stream_ids: + user_ids_in_cache.add(user_id) + + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + content = self._simple_select_one_onecol_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + ) + results.setdefault(user_id, {})[device_id] = json.loads(content) + else: + devices = self._simple_select_list_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + ) + results[user_id] = { + device["device_id"]: json.loads(device["content"]) + for device in devices + } + user_ids_in_cache.discard(user_id) + + return user_ids_not_in_cache, results + + def get_devices_with_keys_by_user(self, user_id): + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + for user_id, user_devices in devices.iteritems(): + results = [] + for device_id, device in user_devices.iteritems(): + result = { + "device_id": device_id, + } + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if 
device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + def mark_as_sent_devices_by_remote(self, destination, stream_id): return self.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, @@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore): defer.returnValue(set(row["user_id"] for row in rows)) @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_id, hosts): + def add_device_change_to_streams(self, user_id, device_ids, hosts): # device_lists_stream # device_lists_outbound_pokes with self._device_list_id_gen.get_next() as stream_id: yield self.runInteraction( "add_device_change_to_streams", self._add_device_change_txn, - user_id, device_id, hosts, stream_id, + user_id, device_ids, hosts, stream_id, ) defer.returnValue(stream_id) - def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id): + def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, stream_id, @@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore): host, stream_id, ) - self._simple_insert_txn( + self._simple_insert_many_txn( txn, table="device_lists_stream", - values={ - "stream_id": stream_id, - "user_id": user_id, - "device_id": device_id, - } + values=[ + { + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + for device_id in device_ids + ] ) self._simple_insert_many_txn( @@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore): "sent": False, } for destination in hosts + for device_id in device_ids ] ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index f82943a7a8..a915c790ff 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -52,11 +52,11 @@ class EndToEndKeyStore(SQLBaseStore): query_params = [] for (user_id, device_id) in query_list: - query_clause = "k.user_id = ?" + query_clause = "user_id = ?" query_params.append(user_id) if device_id: - query_clause += " AND k.device_id = ?" + query_clause += " AND device_id = ?" query_params.append(device_id) query_clauses.append(query_clause) diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql index 61cac63bbb..d1051c6ddf 100644 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -13,18 +13,6 @@ * limitations under the License. 
*/ -CREATE TABLE device_list_streams_remote ( - list_id TEXT NOT NULL, - origin TEXT NOT NULL, - user_id TEXT NOT NULL, - is_full BOOLEAN NOT NULL, - ts BIGINT NOT NULL -); - -CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote( - origin, list_id, user_id -); - CREATE TABLE device_lists_remote_cache ( user_id TEXT NOT NULL, @@ -35,6 +23,14 @@ CREATE TABLE device_lists_remote_cache ( CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); +CREATE TABLE device_lists_remote_extremeties ( + user_id TEXT NOT NULL, + stream_id TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); + + CREATE TABLE device_lists_stream ( stream_id BIGINT NOT NULL, user_id TEXT NOT NULL, diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 85a970a6c9..2eaaa8253c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield utils.setup_test_homeserver(handlers=None) - self.handler = synapse.handlers.device.DeviceHandler(hs) + hs = yield utils.setup_test_homeserver() + self.handler = hs.get_device_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() @defer.inlineCallbacks def test_device_is_created_if_doesnt_exist(self): res = yield self.handler.check_device_registered( - user_id="boris", + user_id="@boris:foo", device_id="fco", initial_device_display_name="display name" ) self.assertEqual(res, "fco") - dev = yield self.handler.store.get_device("boris", "fco") + dev = yield self.handler.store.get_device("@boris:foo", "fco") self.assertEqual(dev["display_name"], "display name") @defer.inlineCallbacks def test_device_is_preserved_if_exists(self): res1 = yield self.handler.check_device_registered( - user_id="boris", + user_id="@boris:foo", device_id="fco", initial_device_display_name="display name" ) self.assertEqual(res1, "fco") res2 = yield self.handler.check_device_registered( - user_id="boris", + user_id="@boris:foo", device_id="fco", initial_device_display_name="new display name" ) self.assertEqual(res2, "fco") - dev = yield self.handler.store.get_device("boris", "fco") + dev = yield self.handler.store.get_device("@boris:foo", "fco") self.assertEqual(dev["display_name"], "display name") @defer.inlineCallbacks def test_device_id_is_made_up_if_unspecified(self): device_id = yield self.handler.check_device_registered( - user_id="theresa", + user_id="@theresa:foo", device_id=None, initial_device_display_name="display" ) - dev = yield self.handler.store.get_device("theresa", device_id) + dev = yield self.handler.store.get_device("@theresa:foo", device_id) self.assertEqual(dev["display_name"], "display") @defer.inlineCallbacks diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5d602c1531..ceb9aa5765 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase): def setUp(self): self.mock_federation = Mock(spec=[ "make_query", + "register_edu_handler", ]) self.query_handlers = {} diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f1f664275f..979cebf600 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase): def setUp(self): self.mock_federation = Mock(spec=[ "make_query", + 
"register_edu_handler", ]) self.query_handlers = {} diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 9ff1abcd80..9e98d0e330 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): event_cache_size=1, password_providers=[], ) - hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) + hs = yield setup_test_homeserver( + config=config, + federation_sender=Mock(), + replication_layer=Mock(), + ) self.as_token = "token1" self.as_url = "some_url" @@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): event_cache_size=1, password_providers=[], ) - hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) + hs = yield setup_test_homeserver( + config=config, + federation_sender=Mock(), + replication_layer=Mock(), + ) self.db_pool = hs.get_db_pool() self.as_list = [ @@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, datastore=Mock(), - federation_sender=Mock() + federation_sender=Mock(), + replication_layer=Mock(), ) ApplicationServiceStore(hs) @@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, datastore=Mock(), - federation_sender=Mock() + federation_sender=Mock(), + replication_layer=Mock(), ) with self.assertRaises(ConfigError) as cm: @@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, datastore=Mock(), - federation_sender=Mock() + federation_sender=Mock(), + replication_layer=Mock(), ) with self.assertRaises(ConfigError) as cm: -- cgit 1.4.1 From b3e1f2aa7a1d583119378bb938ad476e72cc35ac Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jan 2017 17:16:24 +0000 Subject: Fix unit tests --- tests/storage/test_end_to_end_keys.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'tests') diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 453bc61438..bfa6294250 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -35,6 +35,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = '{ "key": "value" }' + yield self.store.store_device( + "user", "device", None + ) + yield self.store.set_e2e_device_keys( "user", "device", now, json) @@ -71,6 +75,19 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): def test_multiple_devices(self): now = 1470174257070 + yield self.store.store_device( + "user1", "device1", None + ) + yield self.store.store_device( + "user1", "device2", None + ) + yield self.store.store_device( + "user2", "device1", None + ) + yield self.store.store_device( + "user2", "device2", None + ) + yield self.store.set_e2e_device_keys( "user1", "device1", now, 'json11') yield self.store.set_e2e_device_keys( -- cgit 1.4.1 From c7a26b7c3243c3187a8d12060cb2d2a02d318260 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Jan 2017 17:11:24 +0000 Subject: Fix unit tests --- synapse/handlers/e2e_keys.py | 2 +- synapse/storage/end_to_end_keys.py | 12 ++++++++++-- tests/storage/test_end_to_end_keys.py | 8 ++++---- 3 files changed, 15 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 49b277a1af..e40495d1ab 100644 --- 
a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -194,7 +194,7 @@ class E2eKeysHandler(object):
         # "unsigned" section
         for user_id, device_keys in results.items():
             for device_id, device_info in device_keys.items():
-                r = json.loads(device_info["key_json"])
+                r = dict(device_info["keys"])
                 r["unsigned"] = {}
                 display_name = device_info["device_display_name"]
                 if display_name is not None:
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index aa54d7637c..2040e022fa 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -15,6 +15,7 @@
 from twisted.internet import defer

 from canonicaljson import encode_canonical_json
+import ujson as json

 from ._base import SQLBaseStore

@@ -59,6 +60,7 @@ class EndToEndKeyStore(SQLBaseStore):
             "set_e2e_device_keys", _set_e2e_device_keys_txn
         )

+    @defer.inlineCallbacks
     def get_e2e_device_keys(self, query_list, include_all_devices=False):
         """Fetch a list of device keys.
         Args:
@@ -70,13 +72,19 @@ class EndToEndKeyStore(SQLBaseStore):
             dict containing "key_json", "device_display_name".
         """
         if not query_list:
-            return {}
+            defer.returnValue({})

-        return self.runInteraction(
+        results = yield self.runInteraction(
             "get_e2e_device_keys", self._get_e2e_device_keys_txn,
             query_list, include_all_devices,
         )

+        for user_id, device_keys in results.iteritems():
+            for device_id, device_info in device_keys.iteritems():
+                device_info["keys"] = json.loads(device_info.pop("key_json"))
+
+        defer.returnValue(results)
+
     def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
         query_clauses = []
         query_params = []
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index bfa6294250..84ce492a2c 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -33,7 +33,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_key_without_device_name(self):
         now = 1470174257070
-        json = '{ "key": "value" }'
+        json = {"key": "value"}

         yield self.store.store_device(
             "user", "device", None
         )
@@ -47,14 +47,14 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
         self.assertDictContainsSubset({
-            "key_json": json,
+            "keys": json,
             "device_display_name": None,
         }, dev)

     @defer.inlineCallbacks
     def test_get_key_with_device_name(self):
         now = 1470174257070
-        json = '{ "key": "value" }'
+        json = {"key": "value"}

         yield self.store.set_e2e_device_keys(
             "user", "device", now, json)
@@ -67,7 +67,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
         self.assertDictContainsSubset({
-            "key_json": json,
+            "keys": json,
             "device_display_name": "display_name",
         }, dev)
-- 
cgit 1.4.1

From 692daf6f5439c3c4852934f3bc950ccac2ec6d92 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 31 Jan 2017 16:10:16 +0000
Subject: Remove membership tests for replication

This is because it now relies on the caches stream, which only works on
postgres. We are trying to test with sqlite.
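An alternative to deleting these tests outright would have been to keep them
and skip them on sqlite. A minimal sketch of that approach, assuming a
hypothetical USING_POSTGRES flag (the real test harness has no such constant,
the engine comes from the test homeserver's config), and relying on
twisted.trial honouring a `skip` attribute on test methods:

    from twisted.trial import unittest

    # Assumed flag, for illustration only.
    USING_POSTGRES = False

    class SlavedEventStoreTestCase(unittest.TestCase):
        def test_room_members(self):
            pass  # body as in the removed test below

        if not USING_POSTGRES:
            # trial skips any test method that carries a `skip` attribute
            test_room_members.skip = (
                "relies on the caches stream, which needs postgres"
            )

Either way the tests stop running under sqlite; removal was simply the
smaller change.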
--- tests/replication/slave/storage/test_events.py | 43 -------------------------- 1 file changed, 43 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 6acb8ab758..105e1228bb 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -58,49 +58,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def tearDown(self): [unpatch() for unpatch in self.unpatches] - @defer.inlineCallbacks - def test_room_members(self): - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.replicate() - yield self.check("get_rooms_for_user", (USER_ID,), []) - yield self.check("get_users_in_room", (ROOM_ID,), []) - - # Join the room. - join = yield self.persist(type="m.room.member", key=USER_ID, membership="join") - yield self.replicate() - yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser( - room_id=ROOM_ID, - sender=USER_ID, - membership="join", - event_id=join.event_id, - stream_ordering=join.internal_metadata.stream_ordering, - )]) - yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) - - # Leave the room. - yield self.persist(type="m.room.member", key=USER_ID, membership="leave") - yield self.replicate() - yield self.check("get_rooms_for_user", (USER_ID,), []) - yield self.check("get_users_in_room", (ROOM_ID,), []) - - # Add some other user to the room. - join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join") - yield self.replicate() - yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser( - room_id=ROOM_ID, - sender=USER_ID, - membership="join", - event_id=join.event_id, - stream_ordering=join.internal_metadata.stream_ordering, - )]) - yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) - - yield self.persist( - type="m.room.member", key=USER_ID, membership="join", - ) - yield self.replicate() - yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2, USER_ID]) - @defer.inlineCallbacks def test_get_latest_event_ids_in_room(self): create = yield self.persist(type="m.room.create", key="", creator=USER_ID) -- cgit 1.4.1 From 51adaac953c00ee59101a71de6162cde4a0e0a86 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Feb 2017 10:53:36 +0000 Subject: Fix email push in pusher worker This was broken when device list updates were implemented, as Mailer could no longer instantiate an AuthHandler due to a dependency on federation sending. --- synapse/handlers/auth.py | 80 ++++++++++++++++++-------------- synapse/handlers/register.py | 10 ++-- synapse/push/mailer.py | 4 +- synapse/rest/client/v1/login.py | 5 +- synapse/rest/client/v2_alpha/register.py | 3 +- synapse/server.py | 6 ++- tests/handlers/test_auth.py | 12 ++--- tests/handlers/test_register.py | 7 +-- 8 files changed, 70 insertions(+), 57 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 221d7ea7a2..fffba34383 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -65,6 +65,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? 
self.device_handler = hs.get_device_handler() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -529,37 +530,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): - access_token = self.generate_access_token(user_id) + access_token = self.macaroon_gen.generate_access_token(user_id) yield self.store.add_access_token_to_user(user_id, access_token, device_id) defer.returnValue(access_token) - def generate_access_token(self, user_id, extra_caveats=None): - extra_caveats = extra_caveats or [] - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = access") - # Include a nonce, to make sure that each login gets a different - # access token. - macaroon.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) - for caveat in extra_caveats: - macaroon.add_first_party_caveat(caveat) - return macaroon.serialize() - - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = login") - now = self.hs.get_clock().time_msec() - expiry = now + duration_in_ms - macaroon.add_first_party_caveat("time < %d" % (expiry,)) - return macaroon.serialize() - - def generate_delete_pusher_token(self, user_id): - macaroon = self._generate_base_macaroon(user_id) - macaroon.add_first_party_caveat("type = delete_pusher") - return macaroon.serialize() - def validate_short_term_login_token_and_get_user_id(self, login_token): auth_api = self.hs.get_auth() try: @@ -570,15 +545,6 @@ class AuthHandler(BaseHandler): except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - def _generate_base_macaroon(self, user_id): - macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, - identifier="key", - key=self.hs.config.macaroon_secret_key) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - return macaroon - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) @@ -673,6 +639,48 @@ class AuthHandler(BaseHandler): return False +class MacaroonGeneartor(object): + def __init__(self, hs): + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + self.macaroon_secret_key = hs.config.macaroon_secret_key + + def generate_access_token(self, user_id, extra_caveats=None): + extra_caveats = extra_caveats or [] + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = access") + # Include a nonce, to make sure that each login gets a different + # access token. 
+ macaroon.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) + for caveat in extra_caveats: + macaroon.add_first_party_caveat(caveat) + return macaroon.serialize() + + def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = login") + now = self.clock.time_msec() + expiry = now + duration_in_ms + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() + + def generate_delete_pusher_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = delete_pusher") + return macaroon.serialize() + + def _generate_base_macaroon(self, user_id): + macaroon = pymacaroons.Macaroon( + location=self.server_name, + identifier="key", + key=self.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + return macaroon + + class _AccountHandler(object): """A proxy object that gets passed to password auth providers so they can register new users etc if necessary. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 286f0cef0a..03c6a85fc6 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id = None + self.macaroon_gen = hs.get_macaroon_generator() + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): @@ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler): token = None if generate_token: - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_not_appservice_exclusive(user_id) if generate_token: - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_not_appservice_exclusive(user_id) - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token(user_id) + token = self.macaroon_gen.generate_access_token(user_id) if need_register: yield self.store.register( diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index ce2d31fb98..62d794f22b 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -81,7 +81,7 @@ class Mailer(object): def __init__(self, hs, app_name): self.hs = hs self.store = self.hs.get_datastore() - self.auth_handler = self.hs.get_auth_handler() + self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) self.app_name = app_name @@ -466,7 +466,7 @@ class Mailer(object): def make_unsubscribe_link(self, user_id, app_id, email_address): params = { - "access_token": self.auth_handler.generate_delete_pusher_token(user_id), + 
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), "app_id": app_id, "pushkey": email_address, } diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 0c9cdff3b8..72057f1b0c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -330,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_required_attributes = hs.config.cas_required_attributes self.auth_handler = hs.get_auth_handler() self.handlers = hs.get_handlers() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def on_GET(self, request): @@ -368,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet): yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(registered_user_id) + login_token = self.macaroon_gen.generate_short_term_login_token( + registered_user_id + ) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 3e7a285e10..ccca5a12d5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet): self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler self.device_handler = hs.get_device_handler() + self.macaroon_gen = hs.get_macaroon_generator() @defer.inlineCallbacks def on_POST(self, request): @@ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet): user_id, device_id, initial_display_name ) - access_token = self.auth_handler.generate_access_token( + access_token = self.macaroon_gen.generate_access_token( user_id, ["guest = true"] ) defer.returnValue((200, { diff --git a/synapse/server.py b/synapse/server.py index 0bfb411269..c577032041 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.handlers.auth import AuthHandler +from synapse.handlers.auth import AuthHandler, MacaroonGeneartor from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.e2e_keys import E2eKeysHandler @@ -131,6 +131,7 @@ class HomeServer(object): 'federation_transport_client', 'federation_sender', 'receipts_handler', + 'macaroon_generator', ] def __init__(self, hostname, **kwargs): @@ -213,6 +214,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_macaroon_generator(self): + return MacaroonGeneartor(self) + def build_device_handler(self): return DeviceHandler(self) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 9d013e5ca7..1822dcf1e0 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -34,11 +34,10 @@ class AuthTestCase(unittest.TestCase): self.hs = yield setup_test_homeserver(handlers=None) self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler + self.macaroon_generator = self.hs.get_macaroon_generator() def test_token_is_a_macaroon(self): - self.hs.config.macaroon_secret_key = "this key is a huge secret" - - token = 
self.auth_handler.generate_access_token("some_user") + token = self.macaroon_generator.generate_access_token("some_user") # Check that we can parse the thing with pymacaroons macaroon = pymacaroons.Macaroon.deserialize(token) # The most basic of sanity checks @@ -46,10 +45,9 @@ class AuthTestCase(unittest.TestCase): self.fail("some_user was not in %s" % macaroon.inspect()) def test_macaroon_caveats(self): - self.hs.config.macaroon_secret_key = "this key is a massive secret" self.hs.clock.now = 5000 - token = self.auth_handler.generate_access_token("a_user") + token = self.macaroon_generator.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) def verify_gen(caveat): @@ -74,7 +72,7 @@ class AuthTestCase(unittest.TestCase): def test_short_term_login_token_gives_user_id(self): self.hs.clock.now = 1000 - token = self.auth_handler.generate_short_term_login_token( + token = self.macaroon_generator.generate_short_term_login_token( "a_user", 5000 ) @@ -93,7 +91,7 @@ class AuthTestCase(unittest.TestCase): ) def test_short_term_login_token_cannot_replace_user_id(self): - token = self.auth_handler.generate_short_term_login_token( + token = self.macaroon_generator.generate_short_term_login_token( "a_user", 5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a4380c48b4..c8cf9a63ec 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -41,15 +41,12 @@ class RegistrationTestCase(unittest.TestCase): handlers=None, http_client=None, expire_access_token=True) - self.auth_handler = Mock( + self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret')) + self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.hs.get_handlers().profile_handler = Mock() - self.mock_handler = Mock(spec=[ - "generate_access_token", - ]) - self.hs.get_auth_handler = Mock(return_value=self.auth_handler) @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): -- cgit 1.4.1
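A note on the device_lists change near the top of this series: replacing
_simple_insert_txn with _simple_insert_many_txn lets a single stream update
fan out to one outbound poke per (destination, device_id) pair. The nested
comprehension in that hunk is equivalent to this standalone sketch (all
values are made-up examples):

    stream_id = 5
    user_id = "@alice:example.com"
    hosts = ["hs1.example.com", "hs2.example.com"]
    device_ids = ["DEVICE_A", "DEVICE_B"]

    values = [
        {
            "destination": destination,
            "stream_id": stream_id,
            "user_id": user_id,
            "device_id": device_id,
            "sent": False,
        }
        for destination in hosts
        for device_id in device_ids
    ]

    assert len(values) == len(hosts) * len(device_ids)  # 2 x 2 = 4 rows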
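The get_e2e_device_keys change in the second "Fix unit tests" commit means
the storage layer now parses each device's key_json string before returning,
which is why the tests assert on "keys" rather than "key_json". The
transformation, sketched standalone with stdlib json in place of ujson (and
items() in place of the Python 2 iteritems()):

    import json

    def parse_key_json(results):
        # results: {user_id: {device_id: {"key_json": ..., "device_display_name": ...}}}
        for device_keys in results.values():
            for device_info in device_keys.values():
                device_info["keys"] = json.loads(device_info.pop("key_json"))
        return results

    rows = {
        "user": {
            "device": {"key_json": '{ "key": "value" }', "device_display_name": None},
        },
    }
    assert parse_key_json(rows)["user"]["device"]["keys"] == {"key": "value"}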
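Finally, the macaroon logic extracted in the last commit can be exercised in
isolation. A minimal sketch of generating and then verifying a short-term
login token with pymacaroons, using placeholder values for the server name
and secret key (Synapse takes both from its config):

    import pymacaroons

    SERVER_NAME = "example.com"
    SECRET_KEY = "not-a-real-secret"

    def generate_short_term_login_token(user_id, now_ms, duration_in_ms=2 * 60 * 1000):
        # Same caveat structure as the generate_short_term_login_token above.
        macaroon = pymacaroons.Macaroon(
            location=SERVER_NAME,
            identifier="key",
            key=SECRET_KEY,
        )
        macaroon.add_first_party_caveat("gen = 1")
        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
        macaroon.add_first_party_caveat("type = login")
        macaroon.add_first_party_caveat("time < %d" % (now_ms + duration_in_ms,))
        return macaroon.serialize()

    token = generate_short_term_login_token("@a_user:example.com", now_ms=1000)

    # Verification in the style of the unit tests above: exact caveats are
    # matched literally, the expiry caveat via a predicate.
    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("user_id = @a_user:example.com")
    v.satisfy_exact("type = login")
    v.satisfy_general(lambda caveat: caveat.startswith("time < "))
    assert v.verify(pymacaroons.Macaroon.deserialize(token), SECRET_KEY)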