diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 259d5f6f88..d4457af950 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -20,6 +20,7 @@ from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
import collections
+import contextlib
import logging
@@ -112,16 +113,19 @@ class _PerHostRatelimiter(object):
or self.request_times
)
+ @contextlib.contextmanager
def ratelimit(self):
- request_id = object()
-
- def on_enter():
- return self._on_enter(request_id)
-
- def on_exit(exc_type, exc_val, exc_tb):
- return self._on_exit(request_id)
+ # `contextlib.contextmanager` takes a generator and turns it into a
+ # context manager. The generator should only yield once with a value
+ # to be returned by manager.
+ # Exceptions will be reraised at the yield.
- return ContextManagerFunction(on_enter, on_exit)
+ request_id = object()
+ ret = self._on_enter(request_id)
+ try:
+ yield ret
+ finally:
+ self._on_exit(request_id)
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
@@ -210,17 +214,3 @@ class _PerHostRatelimiter(object):
deferred.callback(None)
except KeyError:
pass
-
-
-class ContextManagerFunction(object):
- def __init__(self, on_enter, on_exit):
- self.on_enter = on_enter
- self.on_exit = on_exit
-
- def __enter__(self):
- if self.on_enter:
- return self.on_enter()
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self.on_exit:
- return self.on_exit(exc_type, exc_val, exc_tb)
|