summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/util/ratelimitutils.py34
1 files changed, 12 insertions, 22 deletions
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 259d5f6f88..d4457af950 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -20,6 +20,7 @@ from synapse.api.errors import LimitExceededError
 from synapse.util.async import sleep
 
 import collections
+import contextlib
 import logging
 
 
@@ -112,16 +113,19 @@ class _PerHostRatelimiter(object):
             or self.request_times
         )
 
+    @contextlib.contextmanager
     def ratelimit(self):
-        request_id = object()
-
-        def on_enter():
-            return self._on_enter(request_id)
-
-        def on_exit(exc_type, exc_val, exc_tb):
-            return self._on_exit(request_id)
+        # `contextlib.contextmanager` takes a generator and turns it into a
+        # context manager. The generator should only yield once with a value
+        # to be returned by manager.
+        # Exceptions will be reraised at the yield.
 
-        return ContextManagerFunction(on_enter, on_exit)
+        request_id = object()
+        ret = self._on_enter(request_id)
+        try:
+            yield ret
+        finally:
+            self._on_exit(request_id)
 
     def _on_enter(self, request_id):
         time_now = self.clock.time_msec()
@@ -210,17 +214,3 @@ class _PerHostRatelimiter(object):
             deferred.callback(None)
         except KeyError:
             pass
-
-
-class ContextManagerFunction(object):
-    def __init__(self, on_enter, on_exit):
-        self.on_enter = on_enter
-        self.on_exit = on_exit
-
-    def __enter__(self):
-        if self.on_enter:
-            return self.on_enter()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if self.on_exit:
-            return self.on_exit(exc_type, exc_val, exc_tb)