diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index be9934c66f..84f222b3db 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -35,6 +35,56 @@ sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
+# TODO(paul):
+# * more generic key management
+# * export monitoring stats
+# * consider other eviction strategies - LRU?
+def cached(max_entries=1000):
+ """ A method decorator that applies a memoizing cache around the function.
+
+ The function is presumed to take one additional argument, which is used as
+ the key for the cache. Cache hits are served directly from the cache;
+ misses use the function body to generate the value.
+
+ The wrapped function has an additional member, a callable called
+ "invalidate". This can be used to remove individual entries from the cache.
+
+ The wrapped function has another additional callable, called "prefill",
+ which can be used to insert values into the cache specifically, without
+ calling the calculation function.
+ """
+ def wrap(orig):
+ cache = {}
+
+ def prefill(key, value):
+ while len(cache) > max_entries:
+ # TODO(paul): This feels too biased. However, a random index
+ # would be a bit inefficient, walking the list of keys just
+ # to ignore most of them?
+ del cache[cache.keys()[0]]
+
+ cache[key] = value
+
+ @defer.inlineCallbacks
+ def wrapped(self, key):
+ if key in cache:
+ defer.returnValue(cache[key])
+
+ ret = yield orig(self, key)
+ prefill(key, ret)
+ defer.returnValue(ret)
+
+ def invalidate(key):
+ if key in cache:
+ del cache[key]
+
+ wrapped.invalidate = invalidate
+ wrapped.prefill = prefill
+ return wrapped
+
+ return wrap
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
@@ -586,8 +636,9 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
update_counter = self._get_event_counters.update
+ cache = self._get_event_cache.setdefault(event_id, {})
+
try:
- cache = self._get_event_cache.setdefault(event_id, {})
# Separate cache entries for each way to invoke _get_event_txn
return cache[(check_redacted, get_prev_content, allow_rejected)]
except KeyError:
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index d941b1f387..dc3666efd4 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -23,23 +23,11 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
-class ApplicationServiceCache(object):
- """Caches ApplicationServices and provides utility functions on top.
-
- This class is designed to be invoked on incoming events in order to avoid
- hammering the database every time to extract a list of application service
- regexes.
- """
-
- def __init__(self):
- self.services = []
-
-
class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs)
- self.cache = ApplicationServiceCache()
+ self.services_cache = []
self.cache_defer = self._populate_cache()
@defer.inlineCallbacks
@@ -56,7 +44,7 @@ class ApplicationServiceStore(SQLBaseStore):
token,
)
# update cache TODO: Should this be in the txn?
- for service in self.cache.services:
+ for service in self.services_cache:
if service.token == token:
service.url = None
service.namespaces = None
@@ -110,13 +98,13 @@ class ApplicationServiceStore(SQLBaseStore):
)
# update cache TODO: Should this be in the txn?
- for (index, cache_service) in enumerate(self.cache.services):
+ for (index, cache_service) in enumerate(self.services_cache):
if service.token == cache_service.token:
- self.cache.services[index] = service
+ self.services_cache[index] = service
logger.info("Updated: %s", service)
return
# new entry
- self.cache.services.append(service)
+ self.services_cache.append(service)
logger.info("Updated(new): %s", service)
def _update_app_service_txn(self, txn, service):
@@ -160,7 +148,7 @@ class ApplicationServiceStore(SQLBaseStore):
@defer.inlineCallbacks
def get_app_services(self):
yield self.cache_defer # make sure the cache is ready
- defer.returnValue(self.cache.services)
+ defer.returnValue(self.services_cache)
@defer.inlineCallbacks
def get_app_service_by_token(self, token, from_cache=True):
@@ -176,7 +164,7 @@ class ApplicationServiceStore(SQLBaseStore):
yield self.cache_defer # make sure the cache is ready
if from_cache:
- for service in self.cache.services:
+ for service in self.services_cache:
if service.token == token:
defer.returnValue(service)
return
@@ -235,7 +223,7 @@ class ApplicationServiceStore(SQLBaseStore):
# TODO get last successful txn id f.e. service
for service in services.values():
logger.info("Found application service: %s", service)
- self.cache.services.append(ApplicationService(
+ self.services_cache.append(ApplicationService(
token=service["token"],
url=service["url"],
namespaces=service["namespaces"],
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9bf608bc90..58aa376c20 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
from collections import namedtuple
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
from synapse.api.constants import Membership
from synapse.types import UserID
@@ -35,11 +35,6 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
- def __init__(self, *args, **kw):
- super(RoomMemberStore, self).__init__(*args, **kw)
-
- self._user_rooms_cache = {}
-
def _store_room_member_txn(self, txn, event):
"""Store a room member in the database.
"""
@@ -103,7 +98,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (event.room_id, domain))
- self.invalidate_rooms_for_user(target_user_id)
+ self.get_rooms_for_user.invalidate(target_user_id)
@defer.inlineCallbacks
def get_room_member(self, user_id, room_id):
@@ -247,33 +242,12 @@ class RoomMemberStore(SQLBaseStore):
results = self._parse_events_txn(txn, rows)
return results
- # TODO(paul): Create a nice @cached decorator to do this
- # @cached
- # def get_foo(...)
- # ...
- # invalidate_foo = get_foo.invalidator
-
- @defer.inlineCallbacks
+ @cached()
def get_rooms_for_user(self, user_id):
- # TODO(paul): put some performance counters in here so we can easily
- # track what impact this cache is having
- if user_id in self._user_rooms_cache:
- defer.returnValue(self._user_rooms_cache[user_id])
-
- rooms = yield self.get_rooms_for_user_where_membership_is(
+ return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
- # TODO(paul): Consider applying a maximum size; just evict things at
- # random, or consider LRU?
-
- self._user_rooms_cache[user_id] = rooms
- defer.returnValue(rooms)
-
- def invalidate_rooms_for_user(self, user_id):
- if user_id in self._user_rooms_cache:
- del self._user_rooms_cache[user_id]
-
@defer.inlineCallbacks
def user_rooms_intersect(self, user_id_list):
""" Checks whether all the users whose IDs are given in a list share a
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index e06ef35690..6cac8d01ac 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, Table
+from ._base import SQLBaseStore, Table, cached
from collections import namedtuple
@@ -28,10 +28,6 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
- # a write-through cache of DestinationsTable.EntryType indexed by
- # destination string
- destination_retry_cache = {}
-
def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
@@ -211,6 +207,7 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall())
+ @cached()
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
@@ -221,9 +218,6 @@ class TransactionStore(SQLBaseStore):
None if not retrying
Otherwise a DestinationsTable.EntryType for the retry scheme
"""
- if destination in self.destination_retry_cache:
- return defer.succeed(self.destination_retry_cache[destination])
-
return self.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings, destination)
@@ -250,7 +244,9 @@ class TransactionStore(SQLBaseStore):
retry_interval (int) - how long until next retry in ms
"""
- self.destination_retry_cache[destination] = (
+ # As this is the new value, we might as well prefill the cache
+ self.get_destination_retry_timings.prefill(
+ destination,
DestinationsTable.EntryType(
destination,
retry_last_ts,
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
new file mode 100644
index 0000000000..fb306cb784
--- /dev/null
+++ b/tests/storage/test__base.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+from twisted.internet import defer
+
+from synapse.storage._base import cached
+
+
+class CacheDecoratorTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def test_passthrough(self):
+ @cached()
+ def func(self, key):
+ return key
+
+ self.assertEquals((yield func(self, "foo")), "foo")
+ self.assertEquals((yield func(self, "bar")), "bar")
+
+ @defer.inlineCallbacks
+ def test_hit(self):
+ callcount = [0]
+
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ yield func(self, "foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ self.assertEquals((yield func(self, "foo")), "foo")
+ self.assertEquals(callcount[0], 1)
+
+ @defer.inlineCallbacks
+ def test_invalidate(self):
+ callcount = [0]
+
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ yield func(self, "foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ func.invalidate("foo")
+
+ yield func(self, "foo")
+
+ self.assertEquals(callcount[0], 2)
+
+ @defer.inlineCallbacks
+ def test_max_entries(self):
+ callcount = [0]
+
+ @cached(max_entries=10)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ for k in range(0,12):
+ yield func(self, k)
+
+ self.assertEquals(callcount[0], 12)
+
+ # There must have been at least 2 evictions, meaning if we calculate
+ # all 12 values again, we must get called at least 2 more times
+ for k in range(0,12):
+ yield func(self, k)
+
+ self.assertTrue(callcount[0] >= 14,
+ msg="Expected callcount >= 14, got %d" % (callcount[0]))
+
+ @defer.inlineCallbacks
+ def test_prefill(self):
+ callcount = [0]
+
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ func.prefill("foo", 123)
+
+ self.assertEquals((yield func(self, "foo")), 123)
+ self.assertEquals(callcount[0], 0)
|