diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 71d5d92500..c6ce65b4cc 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip)
try:
- last_seen = self.client_ip_last_seen.get(*key)
+ last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
@@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None)
- self.client_ip_last_seen.prefill(*key + (now,))
+ self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@@ -354,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine)
+ elif ext == ".pyc":
+ # Sometimes .pyc files turn up anyway even though we've
+ # disabled their generation; e.g. from distribution package
+ # installers. Silently skip it
+ pass
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8f812f0fd7..73eea157a4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,6 +15,7 @@
import logging
from synapse.api.errors import StoreError
+from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@@ -27,6 +28,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
+import inspect
import sys
import time
import threading
@@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
)
+_CacheSentinel = object()
+
+
class Cache(object):
- def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+ def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
@@ -81,45 +86,44 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, *keyargs):
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
-
- if keyargs in self.cache:
+ def get(self, key, default=_CacheSentinel):
+ val = self.cache.get(key, _CacheSentinel)
+ if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
- return self.cache[keyargs]
+ return val
cache_counter.inc_misses(self.name)
- raise KeyError()
- def update(self, sequence, *args):
+ if default is _CacheSentinel:
+ raise KeyError()
+ else:
+ return default
+
+ def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- self.prefill(*args)
-
- def prefill(self, *args): # because I can't *keyargs, value
- keyargs = args[:-1]
- value = args[-1]
-
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ self.prefill(key, value)
+ def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
- self.cache[keyargs] = value
+ self.cache[key] = value
- def invalidate(self, *keyargs):
+ def invalidate(self, key):
self.check_thread()
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ if not isinstance(key, tuple):
+ raise TypeError(
+ "The cache key must be a tuple not %r" % (type(key),)
+ )
+
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
- self.cache.pop(keyargs, None)
+ self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
@@ -130,6 +134,9 @@ class Cache(object):
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
+ This caches deferreds, rather than the results themselves. Deferreds that
+ fail are removed from the cache.
+
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
@@ -141,58 +148,92 @@ class CacheDescriptor(object):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
- def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+ def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+ inlineCallbacks=False):
self.orig = orig
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
- def __get__(self, obj, objtype=None):
- cache = Cache(
+ self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+ if len(self.arg_names) < self.num_args:
+ raise Exception(
+ "Not enough explicit positional arguments to key off of for %r."
+ " (@cached cannot key off of *args or **kwars)"
+ % (orig.__name__,)
+ )
+
+ self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
+ def __get__(self, obj, objtype=None):
+
@functools.wraps(self.orig)
- @defer.inlineCallbacks
- def wrapped(*keyargs):
+ def wrapped(*args, **kwargs):
+ arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+ cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
- cached_result = cache.get(*keyargs[:self.num_args])
+ cached_result_d = self.cache.get(cache_key)
+
+ observer = cached_result_d.observe()
if DEBUG_CACHES:
- actual_result = yield self.orig(obj, *keyargs)
- if actual_result != cached_result:
- logger.error(
- "Stale cache entry %s%r: cached: %r, actual %r",
- self.orig.__name__, keyargs,
- cached_result, actual_result,
- )
- raise ValueError("Stale cache entry")
- defer.returnValue(cached_result)
+ @defer.inlineCallbacks
+ def check_result(cached_result):
+ actual_result = yield self.function_to_call(obj, *args, **kwargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ self.orig.__name__, cache_key,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
+ observer.addCallback(check_result)
+
+ return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
- sequence = cache.sequence
+ sequence = self.cache.sequence
+
+ ret = defer.maybeDeferred(
+ self.function_to_call,
+ obj, *args, **kwargs
+ )
+
+ def onErr(f):
+ self.cache.invalidate(cache_key)
+ return f
- ret = yield self.orig(obj, *keyargs)
+ ret.addErrback(onErr)
- cache.update(sequence, *keyargs[:self.num_args] + (ret,))
+ ret = ObservableDeferred(ret, consumeErrors=True)
+ self.cache.update(sequence, cache_key, ret)
- defer.returnValue(ret)
+ return ret.observe()
- wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
- wrapped.prefill = cache.prefill
+ wrapped.invalidate = self.cache.invalidate
+ wrapped.invalidate_all = self.cache.invalidate_all
+ wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
-def cached(max_entries=1000, num_args=1, lru=False):
+def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@@ -201,6 +242,16 @@ def cached(max_entries=1000, num_args=1, lru=False):
)
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru,
+ inlineCallbacks=True,
+ )
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 2b2bdf8615..f3947bbe89 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
},
desc="create_room_alias_association",
)
- self.get_aliases_for_room.invalidate(room_id)
+ self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
@@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
room_alias,
)
- self.get_aliases_for_room.invalidate(room_id)
+ self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 45b86c94e8..910b6598a7 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore):
for room_id in events_by_room:
txn.call_after(
- self.get_latest_event_ids_in_room.invalidate, room_id
+ self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
def get_backfill_events(self, room_id, event_list, limit):
@@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
- txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
+ txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ed7ea38804..5b64918024 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
- txn.call_after(self.get_users_in_room.invalidate, event.room_id)
- txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+ txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
+ txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn(
@@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
if not context.rejected:
txn.call_after(
self.get_current_state_for_key.invalidate,
- event.room_id, event.type, event.state_key
- )
+ (event.room_id, event.type, event.state_key,)
+ )
if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after(
self.get_room_name_and_aliases.invalidate,
- event.room_id
+ (event.room_id,)
)
self._simple_upsert_txn(
@@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
- self._get_event_cache.invalidate(event_id, check_redacted,
- get_prev_content)
+ self._get_event_cache.invalidate(
+ (event_id, check_redacted, get_prev_content)
+ )
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
@@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events:
try:
ret = self._get_event_cache.get(
- event_id, check_redacted, get_prev_content
+ (event_id, check_redacted, get_prev_content,)
)
if allow_rejected or not ret.rejected_reason:
@@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
- ev.event_id, check_redacted, get_prev_content, ev
+ (ev.event_id, check_redacted, get_prev_content), ev
)
defer.returnValue(ev)
@@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
- ev.event_id, check_redacted, get_prev_content, ev
+ (ev.event_id, check_redacted, get_prev_content), ev
)
return ev
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 940a5f7e08..49b8e37cfd 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore, cached
+from _base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate",
)
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
table="server_signature_keys",
@@ -132,7 +131,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
- self.get_all_server_verify_keys.invalidate(server_name)
+ self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index fefcf6bce0..576cf670cc 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True},
desc="set_presence_list_accepted",
)
- self.get_presence_list_accepted.invalidate(observer_localpart)
+ self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
@@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid},
desc="del_presence_list",
)
- self.get_presence_list_accepted.invalidate(observer_localpart)
+ self.get_presence_list_accepted.invalidate((observer_localpart,))
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 4cac118d17..9b88ca7b39 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
import logging
@@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list(
table=PushRuleTable.table_name,
@@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows)
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name,
@@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after(
- self.get_push_rules_for_user.invalidate, user_name
+ self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, user_name
+ self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio
txn.call_after(
- self.get_push_rules_for_user.invalidate, user_name
+ self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, user_name
+ self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule",
)
- self.get_push_rules_for_user.invalidate(user_name)
- self.get_push_rules_enabled_for_user.invalidate(user_name)
+ self.get_push_rules_for_user.invalidate((user_name,))
+ self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
@@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id},
)
txn.call_after(
- self.get_push_rules_for_user.invalidate, user_name
+ self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, user_name
+ self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 7a6af98d98..b79d6683ca 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)
- @cached
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers.
"""
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 90e2606be2..4eaa088b36 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
user_id
)
for r in rows:
- self.get_user_by_token.invalidate(r)
+ self.get_user_by_token.invalidate((r,))
@cached()
def get_user_by_token(self, token):
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 4612a8aa83..dd5bc2c8fb 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
import collections
import logging
@@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
}
)
- @cached()
- @defer.inlineCallbacks
+ @cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id):
def f(txn):
sql = (
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 4db07f6fb4..9f14f38f24 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore):
)
for event in events:
- txn.call_after(self.get_rooms_for_user.invalidate, event.state_key)
- txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
- txn.call_after(self.get_users_in_room.invalidate, event.room_id)
+ txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
+ txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
+ txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -78,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None
)
- @cached()
+ @cached(max_entries=5000)
def get_users_in_room(self, room_id):
def f(txn):
@@ -154,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
- @cached()
+ @cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
return self.runInteraction(
"get_joined_hosts_for_room",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 47bec65497..7ce51b9bdc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@@ -91,7 +91,6 @@ class StateStore(SQLBaseStore):
defer.returnValue(dict(state_list))
- @cached(num_args=1)
def _fetch_events_for_group(self, key, events):
return self._get_events(
events, get_prev_content=False
@@ -189,8 +188,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
- @cached(num_args=3)
- @defer.inlineCallbacks
+ @cachedInlineCallbacks(num_args=3)
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
|