summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py17
-rw-r--r--synapse/api/ratelimiting.py14
-rw-r--r--synapse/appservice/__init__.py7
-rw-r--r--synapse/config/appservice.py6
-rw-r--r--synapse/handlers/_base.py6
-rw-r--r--synapse/handlers/message.py17
-rw-r--r--synapse/rest/client/v2_alpha/filter.py12
-rw-r--r--synapse/storage/_base.py1
-rw-r--r--synapse/types.py11
9 files changed, 66 insertions, 25 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1b3b55d517..b6a151a7ec 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -603,10 +603,12 @@ class Auth(object):
         """
         # Can optionally look elsewhere in the request (e.g. headers)
         try:
-            user_id = yield self._get_appservice_user_id(request)
+            user_id, app_service = yield self._get_appservice_user_id(request)
             if user_id:
                 request.authenticated_entity = user_id
-                defer.returnValue(synapse.types.create_requester(user_id))
+                defer.returnValue(
+                    synapse.types.create_requester(user_id, app_service=app_service)
+                )
 
             access_token = get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
@@ -644,7 +646,8 @@ class Auth(object):
             request.authenticated_entity = user.to_string()
 
             defer.returnValue(synapse.types.create_requester(
-                user, token_id, is_guest, device_id))
+                user, token_id, is_guest, device_id, app_service=app_service)
+            )
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -659,14 +662,14 @@ class Auth(object):
             )
         )
         if app_service is None:
-            defer.returnValue(None)
+            defer.returnValue((None, None))
 
         if "user_id" not in request.args:
-            defer.returnValue(app_service.sender)
+            defer.returnValue((app_service.sender, app_service))
 
         user_id = request.args["user_id"][0]
         if app_service.sender == user_id:
-            defer.returnValue(app_service.sender)
+            defer.returnValue((app_service.sender, app_service))
 
         if not app_service.is_interested_in_user(user_id):
             raise AuthError(
@@ -678,7 +681,7 @@ class Auth(object):
                 403,
                 "Application service has not registered this user"
             )
-        defer.returnValue(user_id)
+        defer.returnValue((user_id, app_service))
 
     @defer.inlineCallbacks
     def get_user_by_access_token(self, token, rights="access"):
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 660dfb56e5..06cc8d90b8 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -23,7 +23,7 @@ class Ratelimiter(object):
     def __init__(self):
         self.message_counts = collections.OrderedDict()
 
-    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
+    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
         """Can the user send a message?
         Args:
             user_id: The user sending a message.
@@ -32,12 +32,15 @@ class Ratelimiter(object):
                 second.
             burst_count: How many messages the user can send before being
                 limited.
+            update (bool): Whether to update the message rates or not. This is
+                useful to check if a message would be allowed to be sent before
+                its ready to be actually sent.
         Returns:
             A pair of a bool indicating if they can send a message now and a
                 time in seconds of when they can next send a message.
         """
         self.prune_message_counts(time_now_s)
-        message_count, time_start, _ignored = self.message_counts.pop(
+        message_count, time_start, _ignored = self.message_counts.get(
             user_id, (0., time_now_s, None),
         )
         time_delta = time_now_s - time_start
@@ -52,9 +55,10 @@ class Ratelimiter(object):
             allowed = True
             message_count += 1
 
-        self.message_counts[user_id] = (
-            message_count, time_start, msg_rate_hz
-        )
+        if update:
+            self.message_counts[user_id] = (
+                message_count, time_start, msg_rate_hz
+            )
 
         if msg_rate_hz > 0:
             time_allowed = (
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 126a10efb7..91471f7e89 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -81,7 +81,7 @@ class ApplicationService(object):
     NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
 
     def __init__(self, token, url=None, namespaces=None, hs_token=None,
-                 sender=None, id=None, protocols=None):
+                 sender=None, id=None, protocols=None, rate_limited=True):
         self.token = token
         self.url = url
         self.hs_token = hs_token
@@ -95,6 +95,8 @@ class ApplicationService(object):
         else:
             self.protocols = set()
 
+        self.rate_limited = rate_limited
+
     def _check_namespaces(self, namespaces):
         # Sanity check that it is of the form:
         # {
@@ -234,5 +236,8 @@ class ApplicationService(object):
     def is_exclusive_room(self, room_id):
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
+    def is_rate_limited(self):
+        return self.rate_limited
+
     def __str__(self):
         return "ApplicationService: %s" % (self.__dict__,)
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index d7537e8d44..82c50b8240 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -110,6 +110,11 @@ def _load_appservice(hostname, as_info, config_filename):
     user = UserID(localpart, hostname)
     user_id = user.to_string()
 
+    # Rate limiting for users of this AS is on by default (excludes sender)
+    rate_limited = True
+    if isinstance(as_info.get("rate_limited"), bool):
+        rate_limited = as_info.get("rate_limited")
+
     # namespace checks
     if not isinstance(as_info.get("namespaces"), dict):
         raise KeyError("Requires 'namespaces' object.")
@@ -155,4 +160,5 @@ def _load_appservice(hostname, as_info, config_filename):
         sender=user_id,
         id=as_info["id"],
         protocols=protocols,
+        rate_limited=rate_limited
     )
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 4981643166..90f96209f8 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -57,10 +57,16 @@ class BaseHandler(object):
         time_now = self.clock.time()
         user_id = requester.user.to_string()
 
+        # The AS user itself is never rate limited.
         app_service = self.store.get_app_service_by_user_id(user_id)
         if app_service is not None:
             return  # do not ratelimit app service senders
 
+        # Disable rate limiting of users belonging to any AS that is configured
+        # not to be rate limited in its registration file (rate_limited: true|false).
+        if requester.app_service and not requester.app_service.is_rate_limited():
+            return
+
         allowed, time_allowed = self.ratelimiter.send_message(
             user_id, time_now,
             msg_rate_hz=self.hs.config.rc_messages_per_second,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 30ea9630f7..59eb26beaf 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
@@ -239,6 +239,21 @@ class MessageHandler(BaseHandler):
                 "Tried to send member event through non-member codepath"
             )
 
+        # We check here if we are currently being rate limited, so that we
+        # don't do unnecessary work. We check again just before we actually
+        # send the event.
+        time_now = self.clock.time()
+        allowed, time_allowed = self.ratelimiter.send_message(
+            event.sender, time_now,
+            msg_rate_hz=self.hs.config.rc_messages_per_second,
+            burst_count=self.hs.config.rc_message_burst_count,
+            update=False,
+        )
+        if not allowed:
+            raise LimitExceededError(
+                retry_after_ms=int(1000 * (time_allowed - time_now)),
+            )
+
         user = UserID.from_string(event.sender)
 
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 510f8b2c74..b4084fec62 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, SynapseError, StoreError, Codes
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.types import UserID
 
@@ -45,7 +45,7 @@ class GetFilterRestServlet(RestServlet):
             raise AuthError(403, "Cannot get filters for other users")
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only get filters for local users")
+            raise AuthError(403, "Can only get filters for local users")
 
         try:
             filter_id = int(filter_id)
@@ -59,8 +59,8 @@ class GetFilterRestServlet(RestServlet):
             )
 
             defer.returnValue((200, filter.get_filter_json()))
-        except KeyError:
-            raise SynapseError(400, "No such filter")
+        except (KeyError, StoreError):
+            raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
 
 
 class CreateFilterRestServlet(RestServlet):
@@ -74,6 +74,7 @@ class CreateFilterRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id):
+
         target_user = UserID.from_string(user_id)
         requester = yield self.auth.get_user_by_req(request)
 
@@ -81,10 +82,9 @@ class CreateFilterRestServlet(RestServlet):
             raise AuthError(403, "Cannot create filters for other users")
 
         if not self.hs.is_mine(target_user):
-            raise SynapseError(400, "Can only create filters for local users")
+            raise AuthError(403, "Can only create filters for local users")
 
         content = parse_json_object_from_request(request)
-
         filter_id = yield self.filtering.add_user_filter(
             user_localpart=target_user.localpart,
             user_filter=content,
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 49fa8614f2..d828d6ee1d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -85,7 +85,6 @@ class LoggingTransaction(object):
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)
 
         sql = self.database_engine.convert_param_style(sql)
-
         if args:
             try:
                 sql_logger.debug(
diff --git a/synapse/types.py b/synapse/types.py
index 1694af1250..ffab12df09 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,8 +18,9 @@ from synapse.api.errors import SynapseError
 from collections import namedtuple
 
 
-Requester = namedtuple("Requester",
-                       ["user", "access_token_id", "is_guest", "device_id"])
+Requester = namedtuple("Requester", [
+    "user", "access_token_id", "is_guest", "device_id", "app_service",
+])
 """
 Represents the user making a request
 
@@ -29,11 +30,12 @@ Attributes:
         request, or None if it came via the appservice API or similar
     is_guest (bool):  True if the user making this request is a guest user
     device_id (str|None):  device_id which was set at authentication time
+    app_service (ApplicationService|None):  the AS requesting on behalf of the user
 """
 
 
 def create_requester(user_id, access_token_id=None, is_guest=False,
-                     device_id=None):
+                     device_id=None, app_service=None):
     """
     Create a new ``Requester`` object
 
@@ -43,13 +45,14 @@ def create_requester(user_id, access_token_id=None, is_guest=False,
             request, or None if it came via the appservice API or similar
         is_guest (bool):  True if the user making this request is a guest user
         device_id (str|None):  device_id which was set at authentication time
+        app_service (ApplicationService|None):  the AS requesting on behalf of the user
 
     Returns:
         Requester
     """
     if not isinstance(user_id, UserID):
         user_id = UserID.from_string(user_id)
-    return Requester(user_id, access_token_id, is_guest, device_id)
+    return Requester(user_id, access_token_id, is_guest, device_id, app_service)
 
 
 def get_domain_from_id(string):