summary refs log tree commit diff
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2014-09-02 17:57:04 +0100
committerMark Haines <mark.haines@matrix.org>2014-09-02 17:57:04 +0100
commitc7a7cdf7346c9268b1d4f483b31e1fdc39b6d7e0 (patch)
tree0981bb47e64f47f3a4e57b4010f5fcb4743824d0
parentTest ratelimiter (diff)
downloadsynapse-c7a7cdf7346c9268b1d4f483b31e1fdc39b6d7e0.tar.xz
Add ratelimiting function to basehandler
-rw-r--r--synapse/api/errors.py1
-rwxr-xr-xsynapse/app/homeserver.py1
-rw-r--r--synapse/config/homeserver.py4
-rw-r--r--synapse/handlers/_base.py17
-rw-r--r--synapse/server.py5
5 files changed, 27 insertions, 1 deletions
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 21ededc5ae..3f33ca5b92 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -28,6 +28,7 @@ class Codes(object):
     UNKNOWN = "M_UNKNOWN"
     NOT_FOUND = "M_NOT_FOUND"
     UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
+    LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
 
 
 class CodeMessageException(Exception):
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 606c9c650d..8a7cd07fec 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -247,6 +247,7 @@ def setup():
         upload_dir=os.path.abspath("uploads"),
         db_name=config.database_path,
         tls_context_factory=tls_context_factory,
+        config=config,
     )
 
     hs.register_servlets()
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 18072e3196..a9aa4c735c 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -17,8 +17,10 @@ from .tls import TlsConfig
 from .server import ServerConfig
 from .logger import LoggingConfig
 from .database import DatabaseConfig
+from .ratelimiting import RatelimitConfig
 
-class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig):
+class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
+                       RatelimitConfig):
     pass
 
 if __name__=='__main__':
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index b37c8be964..dc1298366e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from twisted.internet import defer
+from synapse.api.errors import cs_error, Codes
 
 class BaseHandler(object):
 
@@ -25,8 +26,24 @@ class BaseHandler(object):
         self.room_lock = hs.get_room_lock_manager()
         self.state_handler = hs.get_state_handler()
         self.distributor = hs.get_distributor()
+        self.ratelimiter = hs.get_ratelimiter()
+        self.clock = hs.get_clock()
         self.hs = hs
 
+    def ratelimit(self, user_id):
+        time_now = self.clock.time()
+        allowed, time_allowed = self.ratelimiter.send_message(
+            user_id, time_now,
+            msg_rate_hz=self.hs.config.rc_messages_per_second,
+            burst_count=self.hs.config.rc_messsage_burst_count,
+        )
+        if not allowed:
+            raise cs_error(
+                "Limit exceeded",
+                Codes.M_LIMIT_EXCEEDED,
+                retry_after_ms=1000*(time_allowed - time_now),
+            )
+
 
 class BaseRoomHandler(BaseHandler):
 
diff --git a/synapse/server.py b/synapse/server.py
index 3e72b2bcd5..35e311a47d 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -32,6 +32,7 @@ from synapse.util import Clock
 from synapse.util.distributor import Distributor
 from synapse.util.lockutils import LockManager
 from synapse.streams.events import EventSources
+from synapse.api.ratelimiting import Ratelimiter
 
 
 class BaseHomeServer(object):
@@ -73,6 +74,7 @@ class BaseHomeServer(object):
         'resource_for_web_client',
         'resource_for_content_repo',
         'event_sources',
+        'ratelimiter',
     ]
 
     def __init__(self, hostname, **kwargs):
@@ -190,6 +192,9 @@ class HomeServer(BaseHomeServer):
     def build_event_sources(self):
         return EventSources(self)
 
+    def build_ratelimiter(self):
+        return Ratelimiter()
+
     def register_servlets(self):
         """ Register all servlets associated with this HomeServer.
         """