summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py2
-rw-r--r--synapse/util/caches/__init__.py57
-rw-r--r--synapse/util/caches/descriptors.py6
-rw-r--r--synapse/util/caches/expiringcache.py18
-rw-r--r--synapse/util/caches/lrucache.py73
-rw-r--r--synapse/util/caches/stream_change_cache.py20
-rw-r--r--synapse/util/wheel_timer.py97
7 files changed, 236 insertions, 37 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 133671e238..3b9da5b34a 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -42,7 +42,7 @@ class Clock(object):
 
     def time_msec(self):
         """Returns the current system time in miliseconds since epoch."""
-        return self.time() * 1000
+        return int(self.time() * 1000)
 
     def looping_call(self, f, msec):
         l = task.LoopingCall(f)
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 1a14904194..d53569ca49 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -14,6 +14,10 @@
 # limitations under the License.
 
 import synapse.metrics
+from lrucache import LruCache
+import os
+
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 DEBUG_CACHES = False
 
@@ -25,3 +29,56 @@ cache_counter = metrics.register_cache(
     lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
     labels=["name"],
 )
+
+_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
+caches_by_name["string_cache"] = _string_cache
+
+
+KNOWN_KEYS = {
+    key: key for key in
+    (
+        "auth_events",
+        "content",
+        "depth",
+        "event_id",
+        "hashes",
+        "origin",
+        "origin_server_ts",
+        "prev_events",
+        "room_id",
+        "sender",
+        "signatures",
+        "state_key",
+        "type",
+        "unsigned",
+        "user_id",
+    )
+}
+
+
+def intern_string(string):
+    """Takes a (potentially) unicode string and interns using custom cache
+    """
+    return _string_cache.setdefault(string, string)
+
+
+def intern_dict(dictionary):
+    """Takes a dictionary and interns well known keys and their values
+    """
+    return {
+        KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
+        for key, value in dictionary.items()
+    }
+
+
+def _intern_known_values(key, value):
+    intern_str_keys = ("event_id", "room_id")
+    intern_unicode_keys = ("sender", "user_id", "type", "state_key")
+
+    if key in intern_str_keys:
+        return intern(value.encode('ascii'))
+
+    if key in intern_unicode_keys:
+        return intern_string(value)
+
+    return value
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 277854ccbc..35544b19fd 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -28,6 +28,7 @@ from twisted.internet import defer
 
 from collections import OrderedDict
 
+import os
 import functools
 import inspect
 import threading
@@ -38,6 +39,9 @@ logger = logging.getLogger(__name__)
 _CacheSentinel = object()
 
 
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
 class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
@@ -140,6 +144,8 @@ class CacheDescriptor(object):
     """
     def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
                  inlineCallbacks=False):
+        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
+
         self.orig = orig
 
         if inlineCallbacks:
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 62cae99649..2b68c1ac93 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.util.caches import cache_counter, caches_by_name
+
 import logging
 
 
@@ -47,6 +49,8 @@ class ExpiringCache(object):
 
         self._cache = {}
 
+        caches_by_name[cache_name] = self._cache
+
     def start(self):
         if not self._expiry_ms:
             # Don't bother starting the loop if things never expire
@@ -65,14 +69,19 @@ class ExpiringCache(object):
         if self._max_len and len(self._cache.keys()) > self._max_len:
             sorted_entries = sorted(
                 self._cache.items(),
-                key=lambda (k, v): v.time,
+                key=lambda item: item[1].time,
             )
 
             for k, _ in sorted_entries[self._max_len:]:
                 self._cache.pop(k)
 
     def __getitem__(self, key):
-        entry = self._cache[key]
+        try:
+            entry = self._cache[key]
+            cache_counter.inc_hits(self._cache_name)
+        except KeyError:
+            cache_counter.inc_misses(self._cache_name)
+            raise
 
         if self._reset_expiry_on_get:
             entry.time = self._clock.time_msec()
@@ -105,9 +114,12 @@ class ExpiringCache(object):
 
         logger.debug(
             "[%s] _prune_cache before: %d, after len: %d",
-            self._cache_name, begin_length, len(self._cache.keys())
+            self._cache_name, begin_length, len(self._cache)
         )
 
+    def __len__(self):
+        return len(self._cache)
+
 
 class _CacheEntry(object):
     def __init__(self, time, value):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f7423f2fab..f9df445a8d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -29,6 +29,16 @@ def enumerate_leaves(node, depth):
                 yield m
 
 
+class _Node(object):
+    __slots__ = ["prev_node", "next_node", "key", "value"]
+
+    def __init__(self, prev_node, next_node, key, value):
+        self.prev_node = prev_node
+        self.next_node = next_node
+        self.key = key
+        self.value = value
+
+
 class LruCache(object):
     """
     Least-recently-used cache.
@@ -38,10 +48,9 @@ class LruCache(object):
     def __init__(self, max_size, keylen=1, cache_type=dict):
         cache = cache_type()
         self.cache = cache  # Used for introspection.
-        list_root = []
-        list_root[:] = [list_root, list_root, None, None]
-
-        PREV, NEXT, KEY, VALUE = 0, 1, 2, 3
+        list_root = _Node(None, None, None, None)
+        list_root.next_node = list_root
+        list_root.prev_node = list_root
 
         lock = threading.Lock()
 
@@ -55,36 +64,36 @@ class LruCache(object):
 
         def add_node(key, value):
             prev_node = list_root
-            next_node = prev_node[NEXT]
-            node = [prev_node, next_node, key, value]
-            prev_node[NEXT] = node
-            next_node[PREV] = node
+            next_node = prev_node.next_node
+            node = _Node(prev_node, next_node, key, value)
+            prev_node.next_node = node
+            next_node.prev_node = node
             cache[key] = node
 
         def move_node_to_front(node):
-            prev_node = node[PREV]
-            next_node = node[NEXT]
-            prev_node[NEXT] = next_node
-            next_node[PREV] = prev_node
+            prev_node = node.prev_node
+            next_node = node.next_node
+            prev_node.next_node = next_node
+            next_node.prev_node = prev_node
             prev_node = list_root
-            next_node = prev_node[NEXT]
-            node[PREV] = prev_node
-            node[NEXT] = next_node
-            prev_node[NEXT] = node
-            next_node[PREV] = node
+            next_node = prev_node.next_node
+            node.prev_node = prev_node
+            node.next_node = next_node
+            prev_node.next_node = node
+            next_node.prev_node = node
 
         def delete_node(node):
-            prev_node = node[PREV]
-            next_node = node[NEXT]
-            prev_node[NEXT] = next_node
-            next_node[PREV] = prev_node
+            prev_node = node.prev_node
+            next_node = node.next_node
+            prev_node.next_node = next_node
+            next_node.prev_node = prev_node
 
         @synchronized
         def cache_get(key, default=None):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
-                return node[VALUE]
+                return node.value
             else:
                 return default
 
@@ -93,25 +102,25 @@ class LruCache(object):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
-                node[VALUE] = value
+                node.value = value
             else:
                 add_node(key, value)
                 if len(cache) > max_size:
-                    todelete = list_root[PREV]
+                    todelete = list_root.prev_node
                     delete_node(todelete)
-                    cache.pop(todelete[KEY], None)
+                    cache.pop(todelete.key, None)
 
         @synchronized
         def cache_set_default(key, value):
             node = cache.get(key, None)
             if node is not None:
-                return node[VALUE]
+                return node.value
             else:
                 add_node(key, value)
                 if len(cache) > max_size:
-                    todelete = list_root[PREV]
+                    todelete = list_root.prev_node
                     delete_node(todelete)
-                    cache.pop(todelete[KEY], None)
+                    cache.pop(todelete.key, None)
                 return value
 
         @synchronized
@@ -119,8 +128,8 @@ class LruCache(object):
             node = cache.get(key, None)
             if node:
                 delete_node(node)
-                cache.pop(node[KEY], None)
-                return node[VALUE]
+                cache.pop(node.key, None)
+                return node.value
             else:
                 return default
 
@@ -137,8 +146,8 @@ class LruCache(object):
 
         @synchronized
         def cache_clear():
-            list_root[NEXT] = list_root
-            list_root[PREV] = list_root
+            list_root.next_node = list_root
+            list_root.prev_node = list_root
             cache.clear()
 
         @synchronized
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index b37f1c0725..ea8a74ca69 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -18,11 +18,15 @@ from synapse.util.caches import cache_counter, caches_by_name
 
 from blist import sorteddict
 import logging
+import os
 
 
 logger = logging.getLogger(__name__)
 
 
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
 class StreamChangeCache(object):
     """Keeps track of the stream positions of the latest change in a set of entities.
 
@@ -33,7 +37,7 @@ class StreamChangeCache(object):
     old then the cache will simply return all given entities.
     """
     def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
-        self._max_size = max_size
+        self._max_size = int(max_size * CACHE_SIZE_FACTOR)
         self._entity_to_key = {}
         self._cache = sorteddict()
         self._earliest_known_stream_pos = current_stream_pos
@@ -85,6 +89,20 @@ class StreamChangeCache(object):
 
         return result
 
+    def get_all_entities_changed(self, stream_pos):
+        """Returns all entites that have had new things since the given
+        position. If the position is too old it will return None.
+        """
+        assert type(stream_pos) is int
+
+        if stream_pos >= self._earliest_known_stream_pos:
+            keys = self._cache.keys()
+            i = keys.bisect_right(stream_pos)
+
+            return [self._cache[k] for k in keys[i:]]
+        else:
+            return None
+
     def entity_has_changed(self, entity, stream_pos):
         """Informs the cache that the entity has been changed at the given
         position.
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
new file mode 100644
index 0000000000..7412fc57a4
--- /dev/null
+++ b/synapse/util/wheel_timer.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class _Entry(object):
+    __slots__ = ["end_key", "queue"]
+
+    def __init__(self, end_key):
+        self.end_key = end_key
+        self.queue = []
+
+
+class WheelTimer(object):
+    """Stores arbitrary objects that will be returned after their timers have
+    expired.
+    """
+
+    def __init__(self, bucket_size=5000):
+        """
+        Args:
+            bucket_size (int): Size of buckets in ms. Corresponds roughly to the
+                accuracy of the timer.
+        """
+        self.bucket_size = bucket_size
+        self.entries = []
+        self.current_tick = 0
+
+    def insert(self, now, obj, then):
+        """Inserts object into timer.
+
+        Args:
+            now (int): Current time in msec
+            obj (object): Object to be inserted
+            then (int): When to return the object strictly after.
+        """
+        then_key = int(then / self.bucket_size) + 1
+
+        if self.entries:
+            min_key = self.entries[0].end_key
+            max_key = self.entries[-1].end_key
+
+            if then_key <= max_key:
+                # The max here is to protect against inserts for times in the past
+                self.entries[max(min_key, then_key) - min_key].queue.append(obj)
+                return
+
+        next_key = int(now / self.bucket_size) + 1
+        if self.entries:
+            last_key = self.entries[-1].end_key
+        else:
+            last_key = next_key
+
+        # Handle the case when `then` is in the past and `entries` is empty.
+        then_key = max(last_key, then_key)
+
+        # Add empty entries between the end of the current list and when we want
+        # to insert. This ensures there are no gaps.
+        self.entries.extend(
+            _Entry(key) for key in xrange(last_key, then_key + 1)
+        )
+
+        self.entries[-1].queue.append(obj)
+
+    def fetch(self, now):
+        """Fetch any objects that have timed out
+
+        Args:
+            now (ms): Current time in msec
+
+        Returns:
+            list: List of objects that have timed out
+        """
+        now_key = int(now / self.bucket_size)
+
+        ret = []
+        while self.entries and self.entries[0].end_key <= now_key:
+            ret.extend(self.entries.pop(0).queue)
+
+        return ret
+
+    def __len__(self):
+        l = 0
+        for entry in self.entries:
+            l += len(entry.queue)
+        return l