summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py32
-rw-r--r--synapse/util/async.py33
-rw-r--r--synapse/util/caches/dictionary_cache.py25
-rw-r--r--synapse/util/file_consumer.py16
-rw-r--r--synapse/util/logcontext.py15
-rw-r--r--synapse/util/ratelimitutils.py3
6 files changed, 70 insertions, 54 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index fc11e26623..2a3df7c71d 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,15 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import PreserveLoggingContext
-
-from twisted.internet import defer, reactor, task
-
-import time
 import logging
-
 from itertools import islice
 
+import attr
+from twisted.internet import defer, task
+
+from synapse.util.logcontext import PreserveLoggingContext
+
 logger = logging.getLogger(__name__)
 
 
@@ -31,16 +30,24 @@ def unwrapFirstError(failure):
     return failure.value.subFailure
 
 
+@attr.s
 class Clock(object):
-    """A small utility that obtains current time-of-day so that time may be
-    mocked during unit-tests.
-
-    TODO(paul): Also move the sleep() functionality into it
     """
+    A Clock wraps a Twisted reactor and provides utilities on top of it.
+    """
+    _reactor = attr.ib()
+
+    @defer.inlineCallbacks
+    def sleep(self, seconds):
+        d = defer.Deferred()
+        with PreserveLoggingContext():
+            self._reactor.callLater(seconds, d.callback, seconds)
+            res = yield d
+        defer.returnValue(res)
 
     def time(self):
         """Returns the current system time in seconds since epoch."""
-        return time.time()
+        return self._reactor.seconds()
 
     def time_msec(self):
         """Returns the current system time in miliseconds since epoch."""
@@ -56,6 +63,7 @@ class Clock(object):
             msec(float): How long to wait between calls in milliseconds.
         """
         call = task.LoopingCall(f)
+        call.clock = self._reactor
         call.start(msec / 1000.0, now=False)
         return call
 
@@ -73,7 +81,7 @@ class Clock(object):
                 callback(*args, **kwargs)
 
         with PreserveLoggingContext():
-            return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
+            return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
 
     def cancel_call_later(self, timer, ignore_errs=False):
         try:
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 9dd4e6b5bc..1668df4ce6 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -13,15 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer, reactor
+from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.python import failure
 
 from .logcontext import (
     PreserveLoggingContext, make_deferred_yieldable, run_in_background
 )
-from synapse.util import logcontext, unwrapFirstError
+from synapse.util import logcontext, unwrapFirstError, Clock
 
 from contextlib import contextmanager
 
@@ -32,22 +31,6 @@ from six.moves import range
 logger = logging.getLogger(__name__)
 
 
-@defer.inlineCallbacks
-def sleep(seconds):
-    d = defer.Deferred()
-    with PreserveLoggingContext():
-        reactor.callLater(seconds, d.callback, seconds)
-        res = yield d
-    defer.returnValue(res)
-
-
-def run_on_reactor():
-    """ This will cause the rest of the function to be invoked upon the next
-    iteration of the main loop
-    """
-    return sleep(0)
-
-
 class ObservableDeferred(object):
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
@@ -180,13 +163,18 @@ class Linearizer(object):
             # do some work.
 
     """
-    def __init__(self, name=None):
+    def __init__(self, name=None, clock=None):
         if name is None:
             self.name = id(self)
         else:
             self.name = name
         self.key_to_defer = {}
 
+        if not clock:
+            from twisted.internet import reactor
+            clock = Clock(reactor)
+        self._clock = clock
+
     @defer.inlineCallbacks
     def queue(self, key):
         # If there is already a deferred in the queue, we pull it out so that
@@ -227,7 +215,7 @@ class Linearizer(object):
             # the context manager, but it needs to happen while we hold the
             # lock, and the context manager's exit code must be synchronous,
             # so actually this is the only sensible place.
-            yield run_on_reactor()
+            yield self._clock.sleep(0)
 
         else:
             logger.info("Acquired uncontended linearizer lock %r for key %r",
@@ -404,7 +392,7 @@ class DeferredTimeoutError(Exception):
     """
 
 
-def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
+def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     """
     Add a timeout to a deferred by scheduling it to be cancelled after
     timeout seconds.
@@ -419,6 +407,7 @@ def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
     Args:
         deferred (defer.Deferred): deferred to be timed out
         timeout (Number): seconds to time out after
+        reactor (twisted.internet.reactor): the Twisted reactor to use
 
         on_timeout_cancel (callable): A callable which is called immediately
             after the deferred times out, and not if this deferred is
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index bdc21e348f..95793d466d 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -107,29 +107,28 @@ class DictionaryCache(object):
         self.sequence += 1
         self.cache.clear()
 
-    def update(self, sequence, key, value, full=False, known_absent=None):
+    def update(self, sequence, key, value, fetched_keys=None):
         """Updates the entry in the cache
 
         Args:
             sequence
-            key
-            value (dict): The value to update the cache with.
-            full (bool): Whether the given value is the full dict, or just a
-                partial subset there of. If not full then any existing entries
-                for the key will be updated.
-            known_absent (set): Set of keys that we know don't exist in the full
-                dict.
+            key (K)
+            value (dict[X,Y]): The value to update the cache with.
+            fetched_keys (None|set[X]): All of the dictionary keys which were
+                fetched from the database.
+
+                If None, this is the complete value for key K. Otherwise, it
+                is used to infer a list of keys which we know don't exist in
+                the full dict.
         """
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            if known_absent is None:
-                known_absent = set()
-            if full:
-                self._insert(key, value, known_absent)
+            if fetched_keys is None:
+                self._insert(key, value, set())
             else:
-                self._update_or_insert(key, value, known_absent)
+                self._update_or_insert(key, value, fetched_keys)
 
     def _update_or_insert(self, key, value, known_absent):
         # We pop and reinsert as we need to tell the cache the size may have
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 3380970e4e..c78801015b 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import threads, reactor
+from twisted.internet import threads
 
 from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 
@@ -27,6 +27,7 @@ class BackgroundFileConsumer(object):
     Args:
         file_obj (file): The file like object to write to. Closed when
             finished.
+        reactor (twisted.internet.reactor): the Twisted reactor to use
     """
 
     # For PushProducers pause if we have this many unwritten slices
@@ -34,9 +35,11 @@ class BackgroundFileConsumer(object):
     # And resume once the size of the queue is less than this
     _RESUME_ON_QUEUE_SIZE = 2
 
-    def __init__(self, file_obj):
+    def __init__(self, file_obj, reactor):
         self._file_obj = file_obj
 
+        self._reactor = reactor
+
         # Producer we're registered with
         self._producer = None
 
@@ -71,7 +74,10 @@ class BackgroundFileConsumer(object):
         self._producer = producer
         self.streaming = streaming
         self._finished_deferred = run_in_background(
-            threads.deferToThread, self._writer
+            threads.deferToThreadPool,
+            self._reactor,
+            self._reactor.getThreadPool(),
+            self._writer,
         )
         if not streaming:
             self._producer.resumeProducing()
@@ -109,7 +115,7 @@ class BackgroundFileConsumer(object):
                 # producer.
                 if self._producer and self._paused_producer:
                     if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
-                        reactor.callFromThread(self._resume_paused_producer)
+                        self._reactor.callFromThread(self._resume_paused_producer)
 
                 bytes = self._bytes_queue.get()
 
@@ -121,7 +127,7 @@ class BackgroundFileConsumer(object):
                 # If its a pull producer then we need to explicitly ask for
                 # more stuff.
                 if not self.streaming and self._producer:
-                    reactor.callFromThread(self._producer.resumeProducing)
+                    self._reactor.callFromThread(self._producer.resumeProducing)
         except Exception as e:
             self._write_exception = e
             raise
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index a58c723403..df2b71b791 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -60,6 +60,7 @@ class LoggingContext(object):
     __slots__ = [
         "previous_context", "name", "ru_stime", "ru_utime",
         "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
+        "evt_db_fetch_count",
         "usage_start",
         "main_thread", "alive",
         "request", "tag",
@@ -90,6 +91,9 @@ class LoggingContext(object):
         def add_database_scheduled(self, sched_sec):
             pass
 
+        def record_event_fetch(self, event_count):
+            pass
+
         def __nonzero__(self):
             return False
         __bool__ = __nonzero__  # python3
@@ -109,6 +113,9 @@ class LoggingContext(object):
         # sec spent waiting for db txns to be scheduled
         self.db_sched_duration_sec = 0
 
+        # number of events this thread has fetched from the db
+        self.evt_db_fetch_count = 0
+
         # If alive has the thread resource usage when the logcontext last
         # became active.
         self.usage_start = None
@@ -243,6 +250,14 @@ class LoggingContext(object):
         """
         self.db_sched_duration_sec += sched_sec
 
+    def record_event_fetch(self, event_count):
+        """Record a number of events being fetched from the db
+
+        Args:
+            event_count (int): number of events being fetched
+        """
+        self.evt_db_fetch_count += event_count
+
 
 class LoggingContextFilter(logging.Filter):
     """Logging filter that adds values from the current logging context to each
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 0ab63c3d7d..c5a45cef7c 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
 
 from synapse.api.errors import LimitExceededError
 
-from synapse.util.async import sleep
 from synapse.util.logcontext import (
     run_in_background, make_deferred_yieldable,
     PreserveLoggingContext,
@@ -153,7 +152,7 @@ class _PerHostRatelimiter(object):
                 "Ratelimit [%s]: sleeping req",
                 id(request_id),
             )
-            ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
+            ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
 
             self.sleeping_requests.add(request_id)