summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/events/__init__.py6
-rw-r--r--synapse/notifier.py32
-rw-r--r--synapse/storage/_base.py47
-rw-r--r--synapse/storage/events.py4
4 files changed, 62 insertions, 27 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 64e08223b0..e4495ccf12 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -46,9 +46,10 @@ def _event_dict_property(key):
 
 class EventBase(object):
     def __init__(self, event_dict, signatures={}, unsigned={},
-                 internal_metadata_dict={}):
+                 internal_metadata_dict={}, rejected_reason=None):
         self.signatures = signatures
         self.unsigned = unsigned
+        self.rejected_reason = rejected_reason
 
         self._event_dict = event_dict
 
@@ -109,7 +110,7 @@ class EventBase(object):
 
 
 class FrozenEvent(EventBase):
-    def __init__(self, event_dict, internal_metadata_dict={}):
+    def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
         event_dict = dict(event_dict)
 
         # Signatures is a dict of dicts, and this is faster than doing a
@@ -128,6 +129,7 @@ class FrozenEvent(EventBase):
             signatures=signatures,
             unsigned=unsigned,
             internal_metadata_dict=internal_metadata_dict,
+            rejected_reason=rejected_reason,
         )
 
     @staticmethod
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 12573f3f59..d750a6fcf7 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -59,8 +59,8 @@ class _NotificationListener(object):
         self.limit = limit
         self.timeout = timeout
         self.deferred = deferred
-
         self.rooms = rooms
+        self.timer = None
 
     def notified(self):
         return self.deferred.called
@@ -93,6 +93,13 @@ class _NotificationListener(object):
                 self.appservice, set()
             ).discard(self)
 
+        # Cancel the timeout for this notifer if one exists.
+        if self.timer is not None:
+            try:
+                notifier.clock.cancel_call_later(self.timer)
+            except:
+                logger.exception("Failed to cancel notifier timer")
+
 
 class Notifier(object):
     """ This class is responsible for notifying any listeners when there are
@@ -325,14 +332,20 @@ class Notifier(object):
             self._register_with_keys(listener[0])
 
         result = yield callback()
+        timer = [None]
+
         if timeout:
             timed_out = [False]
 
             def _timeout_listener():
                 timed_out[0] = True
+                timer[0] = None
                 listener[0].notify(self, [], from_token, from_token)
 
-            self.clock.call_later(timeout/1000., _timeout_listener)
+            # We create multiple notification listeners so we have to manage
+            # canceling the timeout ourselves.
+            timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
+
             while not result and not timed_out[0]:
                 yield deferred
                 deferred = defer.Deferred()
@@ -347,6 +360,12 @@ class Notifier(object):
                 self._register_with_keys(listener[0])
                 result = yield callback()
 
+        if timer[0] is not None:
+            try:
+                self.clock.cancel_call_later(timer[0])
+            except:
+                logger.exception("Failed to cancel notifer timer")
+
         defer.returnValue(result)
 
     def get_events_for(self, user, rooms, pagination_config, timeout):
@@ -385,6 +404,8 @@ class Notifier(object):
         def _timeout_listener():
             # TODO (erikj): We should probably set to_token to the current
             # max rather than reusing from_token.
+            # Remove the timer from the listener so we don't try to cancel it.
+            listener.timer = None
             listener.notify(
                 self,
                 [],
@@ -400,8 +421,11 @@ class Notifier(object):
         if not timeout:
             _timeout_listener()
         else:
-            self.clock.call_later(timeout/1000.0, _timeout_listener)
-
+            # Only add the timer if the listener hasn't been notified
+            if not listener.notified():
+                listener.timer = self.clock.call_later(
+                    timeout/1000.0, _timeout_listener
+                )
         return
 
     @log_function
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index badf9a5f40..c1bf98cdcb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -251,13 +251,11 @@ class SQLBaseStore(object):
         self._txn_perf_counters = PerformanceCounters()
         self._get_event_counters = PerformanceCounters()
 
-        self._get_event_cache = LruCache(hs.config.event_cache_size)
+        self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
+                                      max_entries=hs.config.event_cache_size)
 
         self.database_engine = hs.database_engine
 
-        # Pretend the getEventCache is just another named cache
-        caches_by_name["*getEvent*"] = self._get_event_cache
-
         self._stream_id_gen = StreamIdGenerator()
         self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
@@ -768,6 +766,12 @@ class SQLBaseStore(object):
 
         return [e for e in events if e]
 
+    def _invalidate_get_event_cache(self, event_id):
+        for check_redacted in (False, True):
+            for get_prev_content in (False, True):
+                self._get_event_cache.invalidate(event_id, check_redacted,
+                                                 get_prev_content)
+
     def _get_event_txn(self, txn, event_id, check_redacted=True,
                        get_prev_content=False, allow_rejected=False):
 
@@ -778,16 +782,14 @@ class SQLBaseStore(object):
             sql_getevents_timer.inc_by(curr_time - last_time, desc)
             return curr_time
 
-        cache = self._get_event_cache.setdefault(event_id, {})
-
         try:
-            # Separate cache entries for each way to invoke _get_event_txn
-            ret = cache[(check_redacted, get_prev_content, allow_rejected)]
+            ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
 
-            cache_counter.inc_hits("*getEvent*")
-            return ret
+            if allow_rejected or not ret.rejected_reason:
+                return ret
+            else:
+                return None
         except KeyError:
-            cache_counter.inc_misses("*getEvent*")
             pass
         finally:
             start_time = update_counter("event_cache", start_time)
@@ -812,19 +814,22 @@ class SQLBaseStore(object):
 
         start_time = update_counter("select_event", start_time)
 
+        result = self._get_event_from_row_txn(
+            txn, internal_metadata, js, redacted,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            rejected_reason=rejected_reason,
+        )
+        self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
+
         if allow_rejected or not rejected_reason:
-            result = self._get_event_from_row_txn(
-                txn, internal_metadata, js, redacted,
-                check_redacted=check_redacted,
-                get_prev_content=get_prev_content,
-            )
-            cache[(check_redacted, get_prev_content, allow_rejected)] = result
             return result
         else:
             return None
 
     def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
-                                check_redacted=True, get_prev_content=False):
+                                check_redacted=True, get_prev_content=False,
+                                rejected_reason=None):
 
         start_time = time.time() * 1000
 
@@ -841,7 +846,11 @@ class SQLBaseStore(object):
         internal_metadata = json.loads(str(internal_metadata).decode("utf8"))
         start_time = update_counter("decode_internal", start_time)
 
-        ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
+        ev = FrozenEvent(
+            d,
+            internal_metadata_dict=internal_metadata,
+            rejected_reason=rejected_reason,
+        )
         start_time = update_counter("build_frozen_event", start_time)
 
         if check_redacted and redacted:
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index f066484c7e..a2e87c27ce 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -93,7 +93,7 @@ class EventsStore(SQLBaseStore):
                            current_state=None):
 
         # Remove the any existing cache entries for the event_id
-        self._get_event_cache.pop(event.event_id)
+        self._invalidate_get_event_cache(event.event_id)
 
         if stream_ordering is None:
             with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
@@ -356,7 +356,7 @@ class EventsStore(SQLBaseStore):
 
     def _store_redaction(self, txn, event):
         # invalidate the cache for the redacted event
-        self._get_event_cache.pop(event.redacts)
+        self._invalidate_get_event_cache(event.redacts)
         txn.execute(
             "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
             (event.event_id, event.redacts)