summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/__init__.py4
-rw-r--r--synapse/storage/_base.py61
-rw-r--r--synapse/storage/directory.py4
-rw-r--r--synapse/storage/event_federation.py4
-rw-r--r--synapse/storage/events.py21
-rw-r--r--synapse/storage/keys.py2
-rw-r--r--synapse/storage/presence.py4
-rw-r--r--synapse/storage/push_rule.py16
-rw-r--r--synapse/storage/registration.py2
-rw-r--r--synapse/storage/roommember.py6
-rw-r--r--tests/storage/test__base.py12
11 files changed, 69 insertions, 67 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 71d5d92500..1a6a8a3762 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
         key = (user.to_string(), access_token, device_id, ip)
 
         try:
-            last_seen = self.client_ip_last_seen.get(*key)
+            last_seen = self.client_ip_last_seen.get(key)
         except KeyError:
             last_seen = None
 
@@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
             defer.returnValue(None)
 
-        self.client_ip_last_seen.prefill(*key + (now,))
+        self.client_ip_last_seen.prefill(key, now)
 
         # It's safe not to lock here: a) no unique constraint,
         # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0872a438f1..e07cf3b58a 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,6 +57,9 @@ cache_counter = metrics.register_cache(
 )
 
 
+_CacheSentinel = object()
+
+
 class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True):
@@ -83,41 +86,38 @@ class Cache(object):
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, *keyargs):
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-
-        if keyargs in self.cache:
+    def get(self, keyargs, default=_CacheSentinel):
+        val = self.cache.get(keyargs, _CacheSentinel)
+        if val is not _CacheSentinel:
             cache_counter.inc_hits(self.name)
-            return self.cache[keyargs]
+            return val
 
         cache_counter.inc_misses(self.name)
-        raise KeyError()
 
-    def update(self, sequence, *args):
+        if default is _CacheSentinel:
+            raise KeyError()
+        else:
+            return default
+
+    def update(self, sequence, keyargs, value):
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(*args)
-
-    def prefill(self, *args):  # because I can't  *keyargs, value
-        keyargs = args[:-1]
-        value = args[-1]
-
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+            self.prefill(keyargs, value)
 
+    def prefill(self, keyargs, value):
         if self.max_entries is not None:
             while len(self.cache) >= self.max_entries:
                 self.cache.popitem(last=False)
 
         self.cache[keyargs] = value
 
-    def invalidate(self, *keyargs):
+    def invalidate(self, keyargs):
         self.check_thread()
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+        if not isinstance(keyargs, tuple):
+            raise ValueError("keyargs must be a tuple.")
+
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
@@ -168,20 +168,21 @@ class CacheDescriptor(object):
                 % (orig.__name__,)
             )
 
-    def __get__(self, obj, objtype=None):
-        cache = Cache(
+        self.cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
         )
 
+    def __get__(self, obj, objtype=None):
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+            keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result_d = cache.get(*keyargs)
+                cached_result_d = self.cache.get(keyargs)
 
                 if DEBUG_CACHES:
 
@@ -202,7 +203,7 @@ class CacheDescriptor(object):
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
+                sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
                     self.function_to_call,
@@ -210,19 +211,19 @@ class CacheDescriptor(object):
                 )
 
                 def onErr(f):
-                    cache.invalidate(*keyargs)
+                    self.cache.invalidate(keyargs)
                     return f
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=False)
-                cache.update(sequence, *(keyargs + [ret]))
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, keyargs, ret)
 
                 return ret.observe()
 
-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.prefill = cache.prefill
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped
 
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 2b2bdf8615..f3947bbe89 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
                 },
                 desc="create_room_alias_association",
             )
-        self.get_aliases_for_room.invalidate(room_id)
+        self.get_aliases_for_room.invalidate((room_id,))
 
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
@@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
             room_alias,
         )
 
-        self.get_aliases_for_room.invalidate(room_id)
+        self.get_aliases_for_room.invalidate((room_id,))
         defer.returnValue(room_id)
 
     def _delete_room_alias_txn(self, txn, room_alias):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 45b86c94e8..910b6598a7 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore):
 
         for room_id in events_by_room:
             txn.call_after(
-                self.get_latest_event_ids_in_room.invalidate, room_id
+                self.get_latest_event_ids_in_room.invalidate, (room_id,)
             )
 
     def get_backfill_events(self, room_id, event_list, limit):
@@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
         query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
 
         txn.execute(query, (room_id,))
-        txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
+        txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ed7ea38804..5b64918024 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
         if current_state:
             txn.call_after(self.get_current_state_for_key.invalidate_all)
             txn.call_after(self.get_rooms_for_user.invalidate_all)
-            txn.call_after(self.get_users_in_room.invalidate, event.room_id)
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
+            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
             txn.call_after(self.get_room_name_and_aliases, event.room_id)
 
             self._simple_delete_txn(
@@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
                 if not context.rejected:
                     txn.call_after(
                         self.get_current_state_for_key.invalidate,
-                        event.room_id, event.type, event.state_key
-                        )
+                        (event.room_id, event.type, event.state_key,)
+                    )
 
                     if event.type in [EventTypes.Name, EventTypes.Aliases]:
                         txn.call_after(
                             self.get_room_name_and_aliases.invalidate,
-                            event.room_id
+                            (event.room_id,)
                         )
 
                     self._simple_upsert_txn(
@@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
     def _invalidate_get_event_cache(self, event_id):
         for check_redacted in (False, True):
             for get_prev_content in (False, True):
-                self._get_event_cache.invalidate(event_id, check_redacted,
-                                                 get_prev_content)
+                self._get_event_cache.invalidate(
+                    (event_id, check_redacted, get_prev_content)
+                )
 
     def _get_event_txn(self, txn, event_id, check_redacted=True,
                        get_prev_content=False, allow_rejected=False):
@@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
         for event_id in events:
             try:
                 ret = self._get_event_cache.get(
-                    event_id, check_redacted, get_prev_content
+                    (event_id, check_redacted, get_prev_content,)
                 )
 
                 if allow_rejected or not ret.rejected_reason:
@@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
                 ev.unsigned["prev_content"] = prev.get_dict()["content"]
 
         self._get_event_cache.prefill(
-            ev.event_id, check_redacted, get_prev_content, ev
+            (ev.event_id, check_redacted, get_prev_content), ev
         )
 
         defer.returnValue(ev)
@@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
                 ev.unsigned["prev_content"] = prev.get_dict()["content"]
 
         self._get_event_cache.prefill(
-            ev.event_id, check_redacted, get_prev_content, ev
+            (ev.event_id, check_redacted, get_prev_content), ev
         )
 
         return ev
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index e3f98f0cde..49b8e37cfd 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore):
             desc="store_server_verify_key",
         )
 
-        self.get_all_server_verify_keys.invalidate(server_name)
+        self.get_all_server_verify_keys.invalidate((server_name,))
 
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index fefcf6bce0..576cf670cc 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
             updatevalues={"accepted": True},
             desc="set_presence_list_accepted",
         )
-        self.get_presence_list_accepted.invalidate(observer_localpart)
+        self.get_presence_list_accepted.invalidate((observer_localpart,))
         defer.returnValue(result)
 
     def get_presence_list(self, observer_localpart, accepted=None):
@@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
                        "observed_user_id": observed_userid},
             desc="del_presence_list",
         )
-        self.get_presence_list_accepted.invalidate(observer_localpart)
+        self.get_presence_list_accepted.invalidate((observer_localpart,))
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index a220f3632e..9b88ca7b39 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -151,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
             txn.execute(sql, (user_name, priority_class, new_rule_priority))
 
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, user_name
+            self.get_push_rules_for_user.invalidate, (user_name,)
         )
 
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, user_name
+            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
         )
 
         self._simple_insert_txn(
@@ -187,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
         new_rule['priority'] = new_prio
 
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, user_name
+            self.get_push_rules_for_user.invalidate, (user_name,)
         )
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, user_name
+            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
         )
 
         self._simple_insert_txn(
@@ -216,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
             desc="delete_push_rule",
         )
 
-        self.get_push_rules_for_user.invalidate(user_name)
-        self.get_push_rules_enabled_for_user.invalidate(user_name)
+        self.get_push_rules_for_user.invalidate((user_name,))
+        self.get_push_rules_enabled_for_user.invalidate((user_name,))
 
     @defer.inlineCallbacks
     def set_push_rule_enabled(self, user_name, rule_id, enabled):
@@ -238,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
             {'id': new_id},
         )
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, user_name
+            self.get_push_rules_for_user.invalidate, (user_name,)
         )
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, user_name
+            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
         )
 
 
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 90e2606be2..4eaa088b36 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
             user_id
         )
         for r in rows:
-            self.get_user_by_token.invalidate(r)
+            self.get_user_by_token.invalidate((r,))
 
     @cached()
     def get_user_by_token(self, token):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 55dd3f6cfb..9f14f38f24 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore):
         )
 
         for event in events:
-            txn.call_after(self.get_rooms_for_user.invalidate, event.state_key)
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
-            txn.call_after(self.get_users_in_room.invalidate, event.room_id)
+            txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
+            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
+            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
 
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8fa305d18a..abee2f631d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -42,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
         self.assertEquals(self.cache.get("foo"), 123)
 
     def test_invalidate(self):
-        self.cache.prefill("foo", 123)
-        self.cache.invalidate("foo")
+        self.cache.prefill(("foo",), 123)
+        self.cache.invalidate(("foo",))
 
         failed = False
         try:
-            self.cache.get("foo")
+            self.cache.get(("foo",))
         except KeyError:
             failed = True
 
@@ -141,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         self.assertEquals(callcount[0], 1)
 
-        a.func.invalidate("foo")
+        a.func.invalidate(("foo",))
 
         yield a.func("foo")
 
@@ -153,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
             def func(self, key):
                 return key
 
-        A().func.invalidate("what")
+        A().func.invalidate(("what",))
 
     @defer.inlineCallbacks
     def test_max_entries(self):
@@ -193,7 +193,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         a = A()
 
-        a.func.prefill("foo", ObservableDeferred(d))
+        a.func.prefill(("foo",), ObservableDeferred(d))
 
         self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)