summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/_base.py34
-rw-r--r--synapse/handlers/message.py16
-rw-r--r--synapse/handlers/profile.py2
-rw-r--r--synapse/handlers/register.py7
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py9
-rw-r--r--synapse/storage/events.py62
-rw-r--r--synapse/storage/room.py36
-rw-r--r--synapse/storage/schema/delta/41/ratelimit.sql22
9 files changed, 146 insertions, 44 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index e83adc8339..faa5609c0c 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,7 +53,20 @@ class BaseHandler(object):
 
         self.event_builder_factory = hs.get_event_builder_factory()
 
-    def ratelimit(self, requester):
+    @defer.inlineCallbacks
+    def ratelimit(self, requester, update=True):
+        """Ratelimits requests.
+
+        Args:
+            requester (Requester)
+            update (bool): Whether to record that a request is being processed.
+                Set to False when doing multiple checks for one request (e.g.
+                to check up front if we would reject the request), and set to
+                True for the last call for a given request.
+
+        Raises:
+            LimitExceededError if the request should be ratelimited
+        """
         time_now = self.clock.time()
         user_id = requester.user.to_string()
 
@@ -67,10 +80,25 @@ class BaseHandler(object):
         if requester.app_service and not requester.app_service.is_rate_limited():
             return
 
+        # Check if there is a per user override in the DB.
+        override = yield self.store.get_ratelimit_for_user(user_id)
+        if override:
+            # If overriden with a null Hz then ratelimiting has been entirely
+            # disabled for the user
+            if not override.messages_per_second:
+                return
+
+            messages_per_second = override.messages_per_second
+            burst_count = override.burst_count
+        else:
+            messages_per_second = self.hs.config.rc_messages_per_second
+            burst_count = self.hs.config.rc_message_burst_count
+
         allowed, time_allowed = self.ratelimiter.send_message(
             user_id, time_now,
-            msg_rate_hz=self.hs.config.rc_messages_per_second,
-            burst_count=self.hs.config.rc_message_burst_count,
+            msg_rate_hz=messages_per_second,
+            burst_count=burst_count,
+            update=update,
         )
         if not allowed:
             raise LimitExceededError(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 57265c6d7d..196925edad 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
+from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
@@ -254,17 +254,7 @@ class MessageHandler(BaseHandler):
         # We check here if we are currently being rate limited, so that we
         # don't do unnecessary work. We check again just before we actually
         # send the event.
-        time_now = self.clock.time()
-        allowed, time_allowed = self.ratelimiter.send_message(
-            event.sender, time_now,
-            msg_rate_hz=self.hs.config.rc_messages_per_second,
-            burst_count=self.hs.config.rc_message_burst_count,
-            update=False,
-        )
-        if not allowed:
-            raise LimitExceededError(
-                retry_after_ms=int(1000 * (time_allowed - time_now)),
-            )
+        yield self.ratelimit(requester, update=False)
 
         user = UserID.from_string(event.sender)
 
@@ -499,7 +489,7 @@ class MessageHandler(BaseHandler):
         # We now need to go and hit out to wherever we need to hit out to.
 
         if ratelimit:
-            self.ratelimit(requester)
+            yield self.ratelimit(requester)
 
         try:
             yield self.auth.check_from_context(event, context)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 9bf638f818..7abee98dea 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
         if not self.hs.is_mine(user):
             return
 
-        self.ratelimit(requester)
+        yield self.ratelimit(requester)
 
         room_ids = yield self.store.get_rooms_for_user(
             user.to_string(),
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 03c6a85fc6..ee3a2269a8 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -54,6 +54,13 @@ class RegistrationHandler(BaseHandler):
                 Codes.INVALID_USERNAME
             )
 
+        if not localpart:
+            raise SynapseError(
+                400,
+                "User ID cannot be empty",
+                Codes.INVALID_USERNAME
+            )
+
         if localpart[0] == '_':
             raise SynapseError(
                 400,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 99cb7db0db..d2a0d6520a 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -75,7 +75,7 @@ class RoomCreationHandler(BaseHandler):
         """
         user_id = requester.user.to_string()
 
-        self.ratelimit(requester)
+        yield self.ratelimit(requester)
 
         if "room_alias_name" in config:
             for wchar in string.whitespace:
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 6a7cd96ea5..1421c18152 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -21,7 +21,7 @@ from synapse.api.auth import get_access_token_from_request, has_access_token
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
 from synapse.http.servlet import (
-    RestServlet, parse_json_object_from_request, assert_params_in_request
+    RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
 )
 from synapse.util.msisdn import phone_number_to_msisdn
 
@@ -142,15 +142,14 @@ class UsernameAvailabilityRestServlet(RestServlet):
         )
 
     @defer.inlineCallbacks
-    def on_POST(self, request):
+    def on_GET(self, request):
         ip = self.hs.get_ip_from_request(request)
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             yield wait_deferred
 
-            body = parse_json_object_from_request(request)
-            assert_params_in_request(body, ['username'])
+            username = parse_string(request, "username", required=True)
 
-            yield self.registration_handler.check_username(body['username'])
+            yield self.registration_handler.check_username(username)
 
             defer.returnValue((200, {"available": True}))
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 2ab44ceaa7..dbd63078c6 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -2033,6 +2033,8 @@ class EventsStore(SQLBaseStore):
         for event_id, state_key in event_rows:
             txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
 
+        logger.debug("[purge] Finding new backward extremities")
+
         # We calculate the new entries for the backward extremeties by finding
         # all events that point to events that are to be purged
         txn.execute(
@@ -2045,6 +2047,8 @@ class EventsStore(SQLBaseStore):
         )
         new_backwards_extrems = txn.fetchall()
 
+        logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
+
         txn.execute(
             "DELETE FROM event_backward_extremities WHERE room_id = ?",
             (room_id,)
@@ -2059,6 +2063,8 @@ class EventsStore(SQLBaseStore):
             ]
         )
 
+        logger.debug("[purge] finding redundant state groups")
+
         # Get all state groups that are only referenced by events that are
         # to be deleted.
         txn.execute(
@@ -2074,15 +2080,19 @@ class EventsStore(SQLBaseStore):
         )
 
         state_rows = txn.fetchall()
-        state_groups_to_delete = [sg for sg, in state_rows]
+
+        # make a set of the redundant state groups, so that we can look them up
+        # efficiently
+        state_groups_to_delete = set([sg for sg, in state_rows])
 
         # Now we get all the state groups that rely on these state groups
-        new_state_edges = []
-        chunks = [
-            state_groups_to_delete[i:i + 100]
-            for i in xrange(0, len(state_groups_to_delete), 100)
-        ]
-        for chunk in chunks:
+        logger.debug("[purge] finding state groups which depend on redundant"
+                     " state groups")
+        remaining_state_groups = []
+        for i in xrange(0, len(state_rows), 100):
+            chunk = [sg for sg, in state_rows[i:i + 100]]
+            # look for state groups whose prev_state_group is one we are about
+            # to delete
             rows = self._simple_select_many_txn(
                 txn,
                 table="state_group_edges",
@@ -2091,21 +2101,28 @@ class EventsStore(SQLBaseStore):
                 retcols=["state_group"],
                 keyvalues={},
             )
-            new_state_edges.extend(row["state_group"] for row in rows)
+            remaining_state_groups.extend(
+                row["state_group"] for row in rows
+
+                # exclude state groups we are about to delete: no point in
+                # updating them
+                if row["state_group"] not in state_groups_to_delete
+            )
 
-        # Now we turn the state groups that reference to-be-deleted state groups
-        # to non delta versions.
-        for new_state_edge in new_state_edges:
+        # Now we turn the state groups that reference to-be-deleted state
+        # groups to non delta versions.
+        for sg in remaining_state_groups:
+            logger.debug("[purge] de-delta-ing remaining state group %s", sg)
             curr_state = self._get_state_groups_from_groups_txn(
-                txn, [new_state_edge], types=None
+                txn, [sg], types=None
             )
-            curr_state = curr_state[new_state_edge]
+            curr_state = curr_state[sg]
 
             self._simple_delete_txn(
                 txn,
                 table="state_groups_state",
                 keyvalues={
-                    "state_group": new_state_edge,
+                    "state_group": sg,
                 }
             )
 
@@ -2113,7 +2130,7 @@ class EventsStore(SQLBaseStore):
                 txn,
                 table="state_group_edges",
                 keyvalues={
-                    "state_group": new_state_edge,
+                    "state_group": sg,
                 }
             )
 
@@ -2122,7 +2139,7 @@ class EventsStore(SQLBaseStore):
                 table="state_groups_state",
                 values=[
                     {
-                        "state_group": new_state_edge,
+                        "state_group": sg,
                         "room_id": room_id,
                         "type": key[0],
                         "state_key": key[1],
@@ -2132,6 +2149,7 @@ class EventsStore(SQLBaseStore):
                 ],
             )
 
+        logger.debug("[purge] removing redundant state groups")
         txn.executemany(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             state_rows
@@ -2140,12 +2158,15 @@ class EventsStore(SQLBaseStore):
             "DELETE FROM state_groups WHERE id = ?",
             state_rows
         )
+
         # Delete all non-state
+        logger.debug("[purge] removing events from event_to_state_groups")
         txn.executemany(
             "DELETE FROM event_to_state_groups WHERE event_id = ?",
             [(event_id,) for event_id, _ in event_rows]
         )
 
+        logger.debug("[purge] updating room_depth")
         txn.execute(
             "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
             (topological_ordering, room_id,)
@@ -2171,16 +2192,15 @@ class EventsStore(SQLBaseStore):
             "event_signatures",
             "rejections",
         ):
+            logger.debug("[purge] removing non-state events from %s", table)
+
             txn.executemany(
                 "DELETE FROM %s WHERE event_id = ?" % (table,),
                 to_delete
             )
 
-        txn.executemany(
-            "DELETE FROM events WHERE event_id = ?",
-            to_delete
-        )
         # Mark all state and own events as outliers
+        logger.debug("[purge] marking events as outliers")
         txn.executemany(
             "UPDATE events SET outlier = ?"
             " WHERE event_id = ?",
@@ -2190,6 +2210,8 @@ class EventsStore(SQLBaseStore):
             ]
         )
 
+        logger.debug("[purge] done")
+
     @defer.inlineCallbacks
     def is_event_after(self, event_id1, event_id2):
         """Returns True if event_id1 is after event_id2 in the stream
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index e4c56cc175..5d543652bb 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 from ._base import SQLBaseStore
 from .engines import PostgresEngine, Sqlite3Engine
@@ -33,6 +33,11 @@ OpsLevel = collections.namedtuple(
     ("ban_level", "kick_level", "redact_level",)
 )
 
+RatelimitOverride = collections.namedtuple(
+    "RatelimitOverride",
+    ("messages_per_second", "burst_count",)
+)
+
 
 class RoomStore(SQLBaseStore):
 
@@ -473,3 +478,32 @@ class RoomStore(SQLBaseStore):
         return self.runInteraction(
             "get_all_new_public_rooms", get_all_new_public_rooms
         )
+
+    @cachedInlineCallbacks(max_entries=10000)
+    def get_ratelimit_for_user(self, user_id):
+        """Check if there are any overrides for ratelimiting for the given
+        user
+
+        Args:
+            user_id (str)
+
+        Returns:
+            RatelimitOverride if there is an override, else None. If the contents
+            of RatelimitOverride are None or 0 then ratelimitng has been
+            disabled for that user entirely.
+        """
+        row = yield self._simple_select_one(
+            table="ratelimit_override",
+            keyvalues={"user_id": user_id},
+            retcols=("messages_per_second", "burst_count"),
+            allow_none=True,
+            desc="get_ratelimit_for_user",
+        )
+
+        if row:
+            defer.returnValue(RatelimitOverride(
+                messages_per_second=row["messages_per_second"],
+                burst_count=row["burst_count"],
+            ))
+        else:
+            defer.returnValue(None)
diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/schema/delta/41/ratelimit.sql
new file mode 100644
index 0000000000..a194bf0238
--- /dev/null
+++ b/synapse/storage/schema/delta/41/ratelimit.sql
@@ -0,0 +1,22 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE ratelimit_override (
+    user_id TEXT NOT NULL,
+    messages_per_second BIGINT,
+    burst_count BIGINT
+);
+
+CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id);