diff options
31 files changed, 405 insertions, 258 deletions
diff --git a/CHANGES.rst b/CHANGES.rst index 68e9d8c671..9106134b46 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -30,6 +30,7 @@ Changes in synapse v0.18.6 (2017-01-06) Bug fixes: * Fix bug when checking if a guest user is allowed to join a room (PR #1772) + Thanks to Patrik Oldsberg for diagnosing and the fix! Changes in synapse v0.18.6-rc3 (2017-01-05) diff --git a/README.rst b/README.rst index ba21c52ae7..77e0b470a3 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,7 @@ Installing prerequisites on openSUSE:: python-devel libffi-devel libopenssl-devel libjpeg62-devel Installing prerequisites on OpenBSD:: + doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \ libxslt diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index c1379fdd7d..1900930053 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -76,8 +76,7 @@ class AppserviceServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -87,9 +86,6 @@ class AppserviceServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -109,11 +105,7 @@ class AppserviceServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index b5e1d659e6..4d081eccd1 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -90,8 +90,7 @@ class ClientReaderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -110,9 +109,6 @@ class ClientReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -132,11 +128,7 @@ class ClientReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index c6810b83db..90a4816753 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -86,8 +86,7 @@ class FederationReaderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -101,9 +100,6 @@ class FederationReaderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -123,11 +119,7 @@ class FederationReaderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 23aae8a09c..ec06620efb 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -82,8 +82,7 @@ class FederationSenderServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -93,9 +92,6 @@ class FederationSenderServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -115,11 +111,7 @@ class FederationSenderServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6c69ccd7e2..e0b87468fe 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -107,8 +107,7 @@ def build_resource_for_web_client(hs): class SynapseHomeServer(HomeServer): def _listener_http(self, config, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] tls = listener_config.get("tls", False) site_tag = listener_config.get("tag", port) @@ -175,9 +174,6 @@ class SynapseHomeServer(HomeServer): root_resource = create_resource_tree(resources, root_resource) - if bind_address is not None: - bind_addresses.append(bind_address) - if tls: for address in bind_addresses: reactor.listenSSL( @@ -212,11 +208,7 @@ class SynapseHomeServer(HomeServer): if listener["type"] == "http": self._listener_http(config, listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index a47283e520..ef17b158a5 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -87,8 +87,7 @@ class MediaRepositoryServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -107,9 +106,6 @@ class MediaRepositoryServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -129,11 +125,7 @@ class MediaRepositoryServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 57e097fa11..073f2c2489 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -121,8 +121,7 @@ class PusherServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -132,9 +131,6 @@ class PusherServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -154,11 +150,7 @@ class PusherServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 439daaa60a..4dfc2dc648 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -289,8 +289,7 @@ class SynchrotronServer(HomeServer): def _listen_http(self, listener_config): port = listener_config["port"] - bind_address = listener_config.get("bind_address", None) - bind_addresses = listener_config.get("bind_addresses", []) + bind_addresses = listener_config["bind_addresses"] site_tag = listener_config.get("tag", port) resources = {} for res in listener_config["resources"]: @@ -312,9 +311,6 @@ class SynchrotronServer(HomeServer): root_resource = create_resource_tree(resources, Resource()) - if bind_address is not None: - bind_addresses.append(bind_address) - for address in bind_addresses: reactor.listenTCP( port, @@ -334,11 +330,7 @@ class SynchrotronServer(HomeServer): if listener["type"] == "http": self._listen_http(listener) elif listener["type"] == "manhole": - bind_address = listener.get("bind_address", None) - bind_addresses = listener.get("bind_addresses", []) - - if bind_address is not None: - bind_addresses.append(bind_address) + bind_addresses = listener["bind_addresses"] for address in bind_addresses: reactor.listenTCP( diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 63e69a7e0c..77ded0ad25 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -22,7 +22,6 @@ import yaml from string import Template import os import signal -from synapse.util.debug import debug_deferreds DEFAULT_LOG_CONFIG = Template(""" @@ -71,8 +70,6 @@ class LoggingConfig(Config): self.verbosity = config.get("verbose", 0) self.log_config = self.abspath(config.get("log_config")) self.log_file = self.abspath(config.get("log_file")) - if config.get("full_twisted_stacktraces"): - debug_deferreds() def default_config(self, config_dir_path, server_name, **kwargs): log_file = self.abspath("homeserver.log") @@ -88,11 +85,6 @@ class LoggingConfig(Config): # A yaml python logging config file log_config: "%(log_config)s" - - # Stop twisted from discarding the stack traces of exceptions in - # deferreds by waiting a reactor tick before running a deferred's - # callbacks. - # full_twisted_stacktraces: true """ % locals() def read_arguments(self, args): diff --git a/synapse/config/server.py b/synapse/config/server.py index 5e6b2a68a7..1f9999d57a 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -42,6 +42,15 @@ class ServerConfig(Config): self.listeners = config.get("listeners", []) + for listener in self.listeners: + bind_address = listener.pop("bind_address", None) + bind_addresses = listener.setdefault("bind_addresses", []) + + if bind_address: + bind_addresses.append(bind_address) + elif not bind_addresses: + bind_addresses.append('') + self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) bind_port = config.get("bind_port") @@ -54,7 +63,7 @@ class ServerConfig(Config): self.listeners.append({ "port": bind_port, - "bind_address": bind_host, + "bind_addresses": [bind_host], "tls": True, "type": "http", "resources": [ @@ -73,7 +82,7 @@ class ServerConfig(Config): if unsecure_port: self.listeners.append({ "port": unsecure_port, - "bind_address": bind_host, + "bind_addresses": [bind_host], "tls": False, "type": "http", "resources": [ @@ -92,7 +101,7 @@ class ServerConfig(Config): if manhole: self.listeners.append({ "port": manhole, - "bind_address": "127.0.0.1", + "bind_addresses": ["127.0.0.1"], "type": "manhole", }) @@ -100,7 +109,7 @@ class ServerConfig(Config): if metrics_port: self.listeners.append({ "port": metrics_port, - "bind_address": config.get("metrics_bind_host", "127.0.0.1"), + "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], "tls": False, "type": "http", "resources": [ diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8a76469b77..b2806555cf 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -232,11 +232,12 @@ class RoomMemberHandler(BaseHandler): errcode=Codes.BAD_STATE ) - same_content = content == old_state.content - same_membership = old_membership == effective_membership_state - same_sender = requester.user.to_string() == old_state.sender - if same_sender and same_membership and same_content: - defer.returnValue(old_state) + if old_state: + same_content = content == old_state.content + same_membership = old_membership == effective_membership_state + same_sender = requester.user.to_string() == old_state.sender + if same_sender and same_membership and same_content: + defer.returnValue(old_state) is_host_in_room = yield self._is_host_in_room(current_state_ids) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index b47bf1f92b..a27476bbad 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -52,7 +52,7 @@ def get_badge_count(store, user_id): def get_context_for_event(store, state_handler, ev, user_id): ctx = {} - room_state_ids = yield state_handler.get_current_state_ids(ev.room_id) + room_state_ids = yield store.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 351170edbc..efa77b8c51 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -86,7 +86,11 @@ class HttpTransactionCache(object): pass # execute the function instead. deferred = fn(*args, **kwargs) - observable = ObservableDeferred(deferred) + + # We don't add an errback to the raw deferred, so we ask ObservableDeferred + # to swallow the error. This is fine as the error will still be reported + # to the observers. + observable = ObservableDeferred(deferred, consumeErrors=True) self.transactions[txn_key] = (observable, self.clock.time_msec()) return observable.observe() diff --git a/synapse/state.py b/synapse/state.py index 15238cd00c..3a0a874c9e 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -41,7 +41,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) -SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) +SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -79,6 +79,9 @@ class _StateCacheEntry(object): else: self.state_id = _gen_state_id() + def __len__(self): + return len(self.state) + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -101,6 +104,7 @@ class StateHandler(object): clock=self.clock, max_len=SIZE_OF_CACHE, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, + iterable=True, reset_expiry_on_get=True, ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 5620a655eb..963ef999d5 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -169,7 +169,7 @@ class SQLBaseStore(object): max_entries=hs.config.event_cache_size) self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) self._event_fetch_lock = threading.Condition() diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 2821eb89c9..bde3b5cbbc 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -18,13 +18,29 @@ import ujson from twisted.internet import defer -from ._base import SQLBaseStore +from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) -class DeviceInboxStore(SQLBaseStore): +class DeviceInboxStore(BackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, hs): + super(DeviceInboxStore, self).__init__(hs) + + self.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, + self._background_drop_index_device_inbox, + ) @defer.inlineCallbacks def add_messages_to_device_inbox(self, local_messages_by_user_then_device, @@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS device_inbox_stream_id" + ) + txn.close() + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + defer.returnValue(1) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e46ae6502e..b357f22be7 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 39 +SCHEMA_VERSION = 40 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5d18037c7c..768e0a4451 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore): room_id, state_group, state_ids, ) - @cachedInlineCallbacks(num_args=2, cache_context=True) + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=100000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql new file mode 100644 index 0000000000..b9fe1f0480 --- /dev/null +++ b/synapse/storage/schema/delta/40/device_inbox.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- turn the pre-fill startup query into a index-only scan on postgresql. +INSERT into background_updates (update_name, progress_json) + VALUES ('device_inbox_stream_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index'); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7f466c40ac..7d34dd03bf 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -284,7 +284,7 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=1000) + @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8dba61d49f..675bfd5feb 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,7 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) @@ -42,6 +42,25 @@ _CacheSentinel = object() CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) +class CacheEntry(object): + __slots__ = [ + "deferred", "sequence", "callbacks", "invalidated" + ] + + def __init__(self, deferred, sequence, callbacks): + self.deferred = deferred + self.sequence = sequence + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() + + class Cache(object): __slots__ = ( "cache", @@ -51,12 +70,16 @@ class Cache(object): "sequence", "thread", "metrics", + "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False): + def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): cache_type = TreeCache if tree else dict + self._pending_deferred_cache = cache_type() + self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type + max_size=max_entries, keylen=keylen, cache_type=cache_type, + size_callback=(lambda d: len(d.result)) if iterable else None, ) self.name = name @@ -76,7 +99,15 @@ class Cache(object): ) def get(self, key, default=_CacheSentinel, callback=None): - val = self.cache.get(key, _CacheSentinel, callback=callback) + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + if val.sequence == self.sequence: + val.callbacks.update(callbacks) + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: self.metrics.inc_hits() return val @@ -88,15 +119,39 @@ class Cache(object): else: return default - def update(self, sequence, key, value, callback=None): + def set(self, key, value, callback=None): + callbacks = [callback] if callback else [] self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value, callback=callback) + entry = CacheEntry( + deferred=value, + sequence=self.sequence, + callbacks=callbacks, + ) + + entry.callbacks.update(callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def shuffle(result): + if self.sequence == entry.sequence: + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, entry.deferred, entry.callbacks) + else: + entry.invalidate() + else: + entry.invalidate() + return result + + entry.deferred.addCallback(shuffle) def prefill(self, key, value, callback=None): - self.cache.set(key, value, callback=callback) + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): self.check_thread() @@ -108,6 +163,10 @@ class Cache(object): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 + entry = self._pending_deferred_cache.pop(key, None) + if entry: + entry.invalidate() + self.cache.pop(key, None) def invalidate_many(self, key): @@ -119,6 +178,11 @@ class Cache(object): self.sequence += 1 self.cache.del_multi(key) + entry_dict = self._pending_deferred_cache.pop(key, None) + if entry_dict is not None: + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + def invalidate_all(self): self.check_thread() self.sequence += 1 @@ -155,7 +219,7 @@ class CacheDescriptor(object): """ def __init__(self, orig, max_entries=1000, num_args=1, tree=False, - inlineCallbacks=False, cache_context=False): + inlineCallbacks=False, cache_context=False, iterable=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -169,6 +233,8 @@ class CacheDescriptor(object): self.num_args = num_args self.tree = tree + self.iterable = iterable + all_args = inspect.getargspec(orig) self.arg_names = all_args.args[1:num_args + 1] @@ -203,6 +269,7 @@ class CacheDescriptor(object): max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, + iterable=self.iterable, ) @functools.wraps(self.orig) @@ -243,11 +310,6 @@ class CacheDescriptor(object): return preserve_context_over_deferred(observer) except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = cache.sequence - ret = defer.maybeDeferred( preserve_context_over_fn, self.function_to_call, @@ -261,7 +323,7 @@ class CacheDescriptor(object): ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - cache.update(sequence, cache_key, ret, callback=invalidate_callback) + cache.set(cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) @@ -359,7 +421,6 @@ class CacheListDescriptor(object): missing.append(arg) if missing: - sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -382,8 +443,8 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - cache.update( - sequence, tuple(key), observer, + cache.set( + tuple(key), observer, callback=invalidate_callback ) @@ -421,17 +482,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, tree=tree, cache_context=cache_context, + iterable=iterable, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, + iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -439,6 +503,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex tree=tree, inlineCallbacks=True, cache_context=cache_context, + iterable=iterable, ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index b0ca1bb79d..cb6933c61c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,9 @@ import logging logger = logging.getLogger(__name__) -DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): + def __len__(self): + return len(self.value) class DictionaryCache(object): @@ -32,7 +34,7 @@ class DictionaryCache(object): """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) + self.cache = LruCache(max_size=max_entries, size_callback=len) self.name = name self.sequence = 0 diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 080388958f..2987c38a2d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -15,6 +15,7 @@ from synapse.util.caches import register_cache +from collections import OrderedDict import logging @@ -23,7 +24,7 @@ logger = logging.getLogger(__name__) class ExpiringCache(object): def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False): + reset_expiry_on_get=False, iterable=False): """ Args: cache_name (str): Name of this cache, used for logging. @@ -36,6 +37,8 @@ class ExpiringCache(object): evicted based on time. reset_expiry_on_get (bool): If true, will reset the expiry time for an item on access. Defaults to False. + iterable (bool): If true, the size is calculated by summing the + sizes of all entries, rather than the number of entries. """ self._cache_name = cache_name @@ -47,9 +50,13 @@ class ExpiringCache(object): self._reset_expiry_on_get = reset_expiry_on_get - self._cache = {} + self._cache = OrderedDict() - self.metrics = register_cache(cache_name, self._cache) + self.metrics = register_cache(cache_name, self) + + self.iterable = iterable + + self._size_estimate = 0 def start(self): if not self._expiry_ms: @@ -65,15 +72,14 @@ class ExpiringCache(object): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) - # Evict if there are now too many items - if self._max_len and len(self._cache.keys()) > self._max_len: - sorted_entries = sorted( - self._cache.items(), - key=lambda item: item[1].time, - ) + if self.iterable: + self._size_estimate += len(value) - for k, _ in sorted_entries[self._max_len:]: - self._cache.pop(k) + # Evict if there are now too many items + while self._max_len and len(self) > self._max_len: + _key, value = self._cache.popitem(last=False) + if self.iterable: + self._size_estimate -= len(value.value) def __getitem__(self, key): try: @@ -99,7 +105,7 @@ class ExpiringCache(object): # zero expiry time means don't expire. This should never get called # since we have this check in start too. return - begin_length = len(self._cache) + begin_length = len(self) now = self._clock.time_msec() @@ -110,15 +116,20 @@ class ExpiringCache(object): keys_to_delete.add(key) for k in keys_to_delete: - self._cache.pop(k) + value = self._cache.pop(k) + if self.iterable: + self._size_estimate -= len(value.value) logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache) + self._cache_name, begin_length, len(self) ) def __len__(self): - return len(self._cache) + if self.iterable: + return self._size_estimate + else: + return len(self._cache) class _CacheEntry(object): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 9c4c679175..072f9a9d19 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,7 +49,7 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict): + def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): cache = cache_type() self.cache = cache # Used for introspection. list_root = _Node(None, None, None, None) @@ -58,6 +58,12 @@ class LruCache(object): lock = threading.Lock() + def evict(): + while cache_len() > max_size: + todelete = list_root.prev_node + delete_node(todelete) + cache.pop(todelete.key, None) + def synchronized(f): @wraps(f) def inner(*args, **kwargs): @@ -66,6 +72,16 @@ class LruCache(object): return inner + cached_cache_len = [0] + if size_callback is not None: + def cache_len(): + return cached_cache_len[0] + else: + def cache_len(): + return len(cache) + + self.len = synchronized(cache_len) + def add_node(key, value, callbacks=set()): prev_node = list_root next_node = prev_node.next_node @@ -74,6 +90,9 @@ class LruCache(object): next_node.prev_node = node cache[key] = node + if size_callback: + cached_cache_len[0] += size_callback(node.value) + def move_node_to_front(node): prev_node = node.prev_node next_node = node.next_node @@ -92,23 +111,25 @@ class LruCache(object): prev_node.next_node = next_node next_node.prev_node = prev_node + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + for cb in node.callbacks: cb() node.callbacks.clear() @synchronized - def cache_get(key, default=None, callback=None): + def cache_get(key, default=None, callbacks=[]): node = cache.get(key, None) if node is not None: move_node_to_front(node) - if callback: - node.callbacks.add(callback) + node.callbacks.update(callbacks) return node.value else: return default @synchronized - def cache_set(key, value, callback=None): + def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: if value != node.value: @@ -116,21 +137,18 @@ class LruCache(object): cb() node.callbacks.clear() - if callback: - node.callbacks.add(callback) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) + + node.callbacks.update(callbacks) move_node_to_front(node) node.value = value else: - if callback: - callbacks = set([callback]) - else: - callbacks = set() - add_node(key, value, callbacks) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + add_node(key, value, set(callbacks)) + + evict() @synchronized def cache_set_default(key, value): @@ -139,10 +157,7 @@ class LruCache(object): return node.value else: add_node(key, value) - if len(cache) > max_size: - todelete = list_root.prev_node - delete_node(todelete) - cache.pop(todelete.key, None) + evict() return value @synchronized @@ -176,10 +191,6 @@ class LruCache(object): cache.clear() @synchronized - def cache_len(): - return len(cache) - - @synchronized def cache_contains(key): return key in cache @@ -190,7 +201,7 @@ class LruCache(object): self.pop = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi - self.len = cache_len + self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index c31585aea3..fcc341a6b7 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -65,12 +65,27 @@ class TreeCache(object): return popped def values(self): - return [e.value for e in self.root.values()] + return list(iterate_tree_cache_entry(self.root)) def __len__(self): return self.size +def iterate_tree_cache_entry(d): + """Helper function to iterate over the leaves of a tree, i.e. a dict of that + can contain dicts. + """ + if isinstance(d, dict): + for value_d in d.itervalues(): + for value in iterate_tree_cache_entry(value_d): + yield value + else: + if isinstance(d, _Entry): + yield d.value + else: + yield d + + class _Entry(object): __slots__ = ["value"] diff --git a/synapse/util/debug.py b/synapse/util/debug.py deleted file mode 100644 index dc49162e6a..0000000000 --- a/synapse/util/debug.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer, reactor -from functools import wraps -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext - - -def debug_deferreds(): - """Cause all deferreds to wait for a reactor tick before running their - callbacks. This increases the chance of getting a stack trace out of - a defer.inlineCallback since the code waiting on the deferred will get - a chance to add an errback before the deferred runs.""" - - # Helper method for retrieving and restoring the current logging context - # around a callback. - def with_logging_context(fn): - context = LoggingContext.current_context() - - def restore_context_callback(x): - with PreserveLoggingContext(context): - return fn(x) - - return restore_context_callback - - # We are going to modify the __init__ method of defer.Deferred so we - # need to get a copy of the old method so we can still call it. - old__init__ = defer.Deferred.__init__ - - # We need to create a deferred to bounce the callbacks through the reactor - # but we don't want to add a callback when we create that deferred so we - # we create a new type of deferred that uses the old __init__ method. - # This is safe as long as the old __init__ method doesn't invoke an - # __init__ using super. - class Bouncer(defer.Deferred): - __init__ = old__init__ - - # We'll add this as a callback to all Deferreds. Twisted will wait until - # the bouncer deferred resolves before calling the callbacks of the - # original deferred. - def bounce_callback(x): - bouncer = Bouncer() - reactor.callLater(0, with_logging_context(bouncer.callback), x) - return bouncer - - # We'll add this as an errback to all Deferreds. Twisted will wait until - # the bouncer deferred resolves before calling the errbacks of the - # original deferred. - def bounce_errback(x): - bouncer = Bouncer() - reactor.callLater(0, with_logging_context(bouncer.errback), x) - return bouncer - - @wraps(old__init__) - def new__init__(self, *args, **kargs): - old__init__(self, *args, **kargs) - self.addCallbacks(bounce_callback, bounce_errback) - - defer.Deferred.__init__ = new__init__ diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index ab6095564a..8361dd8cee 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=2) + @cached(max_entries=20) # HACK: This makes it 2 due to cache factor def func(self, key): callcount[0] += 1 return key @@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 2) self.assertEquals(callcount2[0], 2) + yield a.func2("foo") + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + yield a.func("foo3") self.assertEquals(callcount[0], 3) diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py new file mode 100644 index 0000000000..31d24adb8b --- /dev/null +++ b/tests/util/test_expiring_cache.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .. import unittest + +from synapse.util.caches.expiringcache import ExpiringCache + +from tests.utils import MockClock + + +class ExpiringCacheTestCase(unittest.TestCase): + + def test_get_set(self): + clock = MockClock() + cache = ExpiringCache("test", clock, max_len=1) + + cache["key"] = "value" + self.assertEquals(cache.get("key"), "value") + self.assertEquals(cache["key"], "value") + + def test_eviction(self): + clock = MockClock() + cache = ExpiringCache("test", clock, max_len=2) + + cache["key"] = "value" + cache["key2"] = "value2" + self.assertEquals(cache.get("key"), "value") + self.assertEquals(cache.get("key2"), "value2") + + cache["key3"] = "value3" + self.assertEquals(cache.get("key"), None) + self.assertEquals(cache.get("key2"), "value2") + self.assertEquals(cache.get("key3"), "value3") + + def test_iterable_eviction(self): + clock = MockClock() + cache = ExpiringCache("test", clock, max_len=5, iterable=True) + + cache["key"] = [1] + cache["key2"] = [2, 3] + cache["key3"] = [4, 5] + + self.assertEquals(cache.get("key"), [1]) + self.assertEquals(cache.get("key2"), [2, 3]) + self.assertEquals(cache.get("key3"), [4, 5]) + + cache["key4"] = [6, 7] + self.assertEquals(cache.get("key"), None) + self.assertEquals(cache.get("key2"), None) + self.assertEquals(cache.get("key3"), [4, 5]) + self.assertEquals(cache.get("key4"), [6, 7]) + + def test_time_eviction(self): + clock = MockClock() + cache = ExpiringCache("test", clock, expiry_ms=1000) + cache.start() + + cache["key"] = 1 + clock.advance_time(0.5) + cache["key2"] = 2 + + self.assertEquals(cache.get("key"), 1) + self.assertEquals(cache.get("key2"), 2) + + clock.advance_time(0.9) + self.assertEquals(cache.get("key"), None) + self.assertEquals(cache.get("key2"), 2) + + clock.advance_time(1) + self.assertEquals(cache.get("key"), None) + self.assertEquals(cache.get("key2"), None) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 1eba5b535e..dfb78cb8bd 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.get("key", "value") @@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): cache.set("key", "value") self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) - cache.get("key", callback=m) + cache.get("key", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value2") @@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.set("key", "value") @@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m = Mock() cache = LruCache(1) - cache.set("key", "value", m) + cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) cache.pop("key") @@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m4 = Mock() cache = LruCache(4, 2, cache_type=TreeCache) - cache.set(("a", "1"), "value", m1) - cache.set(("a", "2"), "value", m2) - cache.set(("b", "1"), "value", m3) - cache.set(("b", "2"), "value", m4) + cache.set(("a", "1"), "value", callbacks=[m1]) + cache.set(("a", "2"), "value", callbacks=[m2]) + cache.set(("b", "1"), "value", callbacks=[m3]) + cache.set(("b", "2"), "value", callbacks=[m4]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m2 = Mock() cache = LruCache(5) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) @@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase): m3 = Mock(name="m3") cache = LruCache(2) - cache.set("key1", "value", m1) - cache.set("key2", "value", m2) + cache.set("key1", "value", callbacks=[m1]) + cache.set("key2", "value", callbacks=[m2]) self.assertEquals(m1.call_count, 0) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key3", "value", m3) + cache.set("key3", "value", callbacks=[m3]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) @@ -227,8 +227,33 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 0) - cache.set("key1", "value", m1) + cache.set("key1", "value", callbacks=[m1]) self.assertEquals(m1.call_count, 1) self.assertEquals(m2.call_count, 0) self.assertEquals(m3.call_count, 1) + + +class LruCacheSizedTestCase(unittest.TestCase): + + def test_evict(self): + cache = LruCache(5, size_callback=len) + cache["key1"] = [0] + cache["key2"] = [1, 2] + cache["key3"] = [3] + cache["key4"] = [4] + + self.assertEquals(cache["key1"], [0]) + self.assertEquals(cache["key2"], [1, 2]) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(len(cache), 5) + + cache["key5"] = [5, 6] + + self.assertEquals(len(cache), 4) + self.assertEquals(cache.get("key1"), None) + self.assertEquals(cache.get("key2"), None) + self.assertEquals(cache["key3"], [3]) + self.assertEquals(cache["key4"], [4]) + self.assertEquals(cache["key5"], [5, 6]) |