diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 3f33ca5b92..23ce0af277 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -40,10 +40,13 @@ class CodeMessageException(Exception):
self.code = code
self.msg = msg
+ def error_dict(self):
+ return cs_error(self.msg)
+
class SynapseError(CodeMessageException):
"""A base error which can be caught for all synapse events."""
- def __init__(self, code, msg, errcode=""):
+ def __init__(self, code, msg, errcode=Codes.UNKNOWN):
"""Constructs a synapse error.
Args:
@@ -54,6 +57,11 @@ class SynapseError(CodeMessageException):
super(SynapseError, self).__init__(code, msg)
self.errcode = errcode
+ def error_dict(self):
+ return cs_error(
+ self.msg,
+ self.errcode,
+ )
class RoomError(SynapseError):
"""An error raised when a room event fails."""
@@ -92,13 +100,25 @@ class StoreError(SynapseError):
pass
-def cs_exception(exception):
- if isinstance(exception, SynapseError):
+class LimitExceededError(SynapseError):
+ """A client has sent too many requests and is being throttled.
+ """
+ def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None,
+ errcode=Codes.LIMIT_EXCEEDED):
+ super(LimitExceededError, self).__init__(code, msg, errcode)
+ self.retry_after_ms = retry_after_ms
+
+ def error_dict(self):
return cs_error(
- exception.msg,
- Codes.UNKNOWN if not exception.errcode else exception.errcode)
- elif isinstance(exception, CodeMessageException):
- return cs_error(exception.msg)
+ self.msg,
+ self.errcode,
+ retry_after_ms=self.retry_after_ms,
+ )
+
+
+def cs_exception(exception):
+ if isinstance(exception, CodeMessageException):
+ return exception.error_dict()
else:
logging.error("Unknown exception type: %s", type(exception))
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index c150b60e07..935adea1ac 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
from twisted.internet import defer
-from synapse.api.errors import cs_error, Codes
+from synapse.api.errors import LimitExceededError
class BaseHandler(object):
@@ -38,9 +38,7 @@ class BaseHandler(object):
burst_count=self.hs.config.rc_message_burst_count,
)
if not allowed:
- raise cs_error(
- "Limit exceeded",
- Codes.LIMIT_EXCEEDED,
+ raise LimitExceededError(
retry_after_ms=1000*(time_allowed - time_now),
)
|