diff --git a/synapse/storage/util/caches.py b/synapse/storage/util/caches.py
new file mode 100644
index 0000000000..0877cc79f6
--- /dev/null
+++ b/synapse/storage/util/caches.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.lrucache import LruCache
+from collections import namedtuple
+import threading
+
+
+DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
+
+
+class DictionaryCache(object):
+
+ def __init__(self, name, max_entries=1000):
+ self.cache = LruCache(max_size=max_entries)
+
+ self.name = name
+ self.sequence = 0
+ self.thread = None
+ # caches_by_name[name] = self.cache
+
+ class Sentinel(object):
+ __slots__ = []
+
+ self.sentinel = Sentinel()
+
+ def check_thread(self):
+ expected_thread = self.thread
+ if expected_thread is None:
+ self.thread = threading.current_thread()
+ else:
+ if expected_thread is not threading.current_thread():
+ raise ValueError(
+ "Cache objects can only be accessed from the main thread"
+ )
+
+ def get(self, key, dict_keys=None):
+ entry = self.cache.get(key, self.sentinel)
+ if entry is not self.sentinel:
+ # cache_counter.inc_hits(self.name)
+
+ if dict_keys is None:
+ return DictionaryEntry(entry.full, dict(entry.value))
+ else:
+ return DictionaryEntry(entry.full, {
+ k: entry.value[k]
+ for k in dict_keys
+ if k in entry.value
+ })
+
+ # cache_counter.inc_misses(self.name)
+ return DictionaryEntry(False, {})
+
+ def invalidate(self, key):
+ self.check_thread()
+
+ # Increment the sequence number so that any SELECT statements that
+ # raced with the INSERT don't update the cache (SYN-369)
+ self.sequence += 1
+ self.cache.pop(key, None)
+
+ def invalidate_all(self):
+ self.check_thread()
+ self.sequence += 1
+ self.cache.clear()
+
+ def update(self, sequence, key, value, full=False):
+ self.check_thread()
+ if self.sequence == sequence:
+ # Only update the cache if the caches sequence number matches the
+ # number that the cache had before the SELECT was started (SYN-369)
+ if full:
+ self._insert(key, value)
+ else:
+ self._update_or_insert(key, value)
+
+ def _update_or_insert(self, key, value):
+ entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+ entry.value.update(value)
+
+ def _insert(self, key, value):
+ self.cache[key] = DictionaryEntry(True, value)
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
new file mode 100644
index 0000000000..8cb9be6581
--- /dev/null
+++ b/tests/util/test_dict_cache.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+from tests import unittest
+
+from synapse.storage.util.caches import DictionaryCache
+
+
+class DictCacheTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.cache = DictionaryCache("foobar")
+
+ def test_simple_cache_hit_full(self):
+ key = "test_simple_cache_hit_full"
+
+ v = self.cache.get(key)
+ self.assertEqual((False, {}), v)
+
+ seq = self.cache.sequence
+ test_value = {"test": "test_simple_cache_hit_full"}
+ self.cache.update(seq, key, test_value, full=True)
+
+ c = self.cache.get(key)
+ self.assertEqual(test_value, c.value)
+
+ def test_simple_cache_hit_partial(self):
+ key = "test_simple_cache_hit_partial"
+
+ seq = self.cache.sequence
+ test_value = {
+ "test": "test_simple_cache_hit_partial"
+ }
+ self.cache.update(seq, key, test_value, full=True)
+
+ c = self.cache.get(key, ["test"])
+ self.assertEqual(test_value, c.value)
+
+ def test_simple_cache_miss_partial(self):
+ key = "test_simple_cache_miss_partial"
+
+ seq = self.cache.sequence
+ test_value = {
+ "test": "test_simple_cache_miss_partial"
+ }
+ self.cache.update(seq, key, test_value, full=True)
+
+ c = self.cache.get(key, ["test2"])
+ self.assertEqual({}, c.value)
+
+ def test_simple_cache_hit_miss_partial(self):
+ key = "test_simple_cache_hit_miss_partial"
+
+ seq = self.cache.sequence
+ test_value = {
+ "test": "test_simple_cache_hit_miss_partial",
+ "test2": "test_simple_cache_hit_miss_partial2",
+ "test3": "test_simple_cache_hit_miss_partial3",
+ }
+ self.cache.update(seq, key, test_value, full=True)
+
+ c = self.cache.get(key, ["test2"])
+ self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
+
+ def test_multi_insert(self):
+ key = "test_simple_cache_hit_miss_partial"
+
+ seq = self.cache.sequence
+ test_value_1 = {
+ "test": "test_simple_cache_hit_miss_partial",
+ }
+ self.cache.update(seq, key, test_value_1, full=False)
+
+ seq = self.cache.sequence
+ test_value_2 = {
+ "test2": "test_simple_cache_hit_miss_partial2",
+ }
+ self.cache.update(seq, key, test_value_2, full=False)
+
+ c = self.cache.get(key)
+ self.assertEqual(
+ {
+ "test": "test_simple_cache_hit_miss_partial",
+ "test2": "test_simple_cache_hit_miss_partial2",
+ },
+ c.value
+ )
|